This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Generalize for multiple utterance features #1388

Open: wants to merge 1 commit into base: main

90 changes: 90 additions & 0 deletions pytext/metric_reporters/word_tagging_metric_reporter.py
@@ -5,6 +5,7 @@
from collections import Counter
from typing import Dict, List, NamedTuple

import torch
from pytext.common.constants import DatasetFieldName, Stage
from pytext.data import CommonMetadata
from pytext.metrics import (
@@ -103,6 +104,43 @@ def __init__(self, label_names, pad_idx, channels, label_vocabs=None):
self.label_vocabs = label_vocabs
super().__init__(channels)

def add_batch_stats(
self, n_batches, preds, targets, scores, loss, m_input, **context
):
"""
Aggregates a batch of output data (predictions, scores, targets/true labels
and loss).

Args:
n_batches (int): number of the current batch
preds (torch.Tensor): predictions of the current batch
targets (torch.Tensor): targets of the current batch
scores (torch.Tensor): scores of the current batch
loss (float): average loss of the current batch
m_input (Tuple[torch.Tensor, ...]): model inputs of the current batch
context (Dict[str, Any]): any additional context data; it can be either
a list of values that maps to each example or a single value for the
whole batch
"""
self.n_batches = n_batches
self.aggregate_preds(preds, context)
self.aggregate_targets(targets, context)
self.aggregate_scores(scores)
for key, val in context.items():
if not isinstance(val, (torch.Tensor, list)):
continue
if key not in self.all_context:
self.all_context[key] = []
self.aggregate_data(self.all_context[key], val)
if loss is not None:
self.all_loss.append(float(loss))
self.batch_size.append(len(m_input[-1]))

# realtime stats
if DatasetFieldName.NUM_TOKENS in context:
self.realtime_meters["tps"].update(context[DatasetFieldName.NUM_TOKENS])
self.realtime_meters["ups"].update(1)
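
# --- Illustrative sketch (not part of this diff) ---
# add_batch_stats above aggregates only per-example context (tensors or lists)
# and skips batch-level scalars. A standalone mock of that filtering rule,
# using hypothetical utterance-feature keys:
import torch

batch_context = {
    "utterance_feature_a": torch.tensor([[1, 0], [0, 1]]),  # kept: tensor, one row per example
    "utterance_feature_b": [["f1"], ["f2"]],                 # kept: list, one entry per example
    "batch_num_tokens": 12,                                  # skipped: single value for the whole batch
}
kept = {k: v for k, v in batch_context.items() if isinstance(v, (torch.Tensor, list))}
assert set(kept) == {"utterance_feature_a", "utterance_feature_b"}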

@classmethod
def from_config(cls, config, tensorizers):
return MultiLabelSequenceTaggingMetricReporter(
@@ -129,6 +167,58 @@ def aggregate_targets(self, batch_targets, batch_context=None):
def aggregate_scores(self, batch_scores):
self.aggregate_tuple_data(self.all_scores, batch_scores)

def report_metric(
self,
model,
stage,
epoch,
reset=True,
print_to_channels=True,
optimizer=None,
privacy_engine=None, # to be handled by the subclassed metric reporters
):
"""
Calculate metrics and average loss, and report all statistics to the channels.

Args:
model (nn.Module): the PyTorch neural network model
stage (Stage): training, evaluation or test
epoch (int): current epoch
reset (bool): whether all data should be reset after reporting, default is True
print_to_channels (bool): whether to report data to channels, default is True
optimizer: optimizer in use, forwarded to each channel's report call
privacy_engine: to be handled by the subclassed metric reporters
"""
self.gen_extra_context()
self.total_loss = self.calculate_loss()
metrics = self.calculate_metric()
model_select_metric = self.get_model_select_metric(metrics)

# print_to_channels is True only on GPU 0, but all GPUs still need to sync
# the metric
self.report_realtime_metric(stage)

if print_to_channels:
for channel in self.channels:
if stage in channel.stages:
channel.report(
stage,
epoch,
metrics,
model_select_metric,
self.total_loss,
self.predictions_to_report(),
self.targets_to_report(),
self.all_scores,
self.all_context,
self.get_meta(),
model,
optimizer,
)

if reset:
self._reset()
self._reset_realtime()
return metrics
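
# --- Hypothetical end-of-epoch call (not part of this diff) ---
# Sketch of how a trainer might invoke report_metric after an evaluation pass.
# `reporter`, `model`, `optimizer`, and `epoch` are assumed to come from the
# surrounding training loop; Stage.EVAL is the reporting stage.
from pytext.common.constants import Stage

eval_metrics = reporter.report_metric(
    model=model,
    stage=Stage.EVAL,
    epoch=epoch,
    reset=True,              # clear the accumulated batch data once reported
    print_to_channels=True,  # typically True only on the GPU 0 worker
    optimizer=optimizer,     # forwarded to each channel's report call
)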

def calculate_metric(self):
list_score_pred_expect = []
for label_idx, _ in enumerate(self.label_names):