Skip to content

Commit

Permalink
Remove compute_on_step from metrics (NVIDIA#6979)
Browse files Browse the repository at this point in the history
* Remove `compute_on_step` from metrics

Signed-off-by: smajumdar <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove confusing log message

Signed-off-by: smajumdar <[email protected]>

* Update tests

Signed-off-by: smajumdar <[email protected]>

---------

Signed-off-by: smajumdar <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and zhehuaichen committed Oct 4, 2023
1 parent 35ab449 commit 47efc37
Show file tree
Hide file tree
Showing 11 changed files with 15 additions and 32 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/rnnt_wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,7 +1224,7 @@ def validation_epoch_end(self, outputs):
def __init__(
self, decoding: RNNTDecoding, batch_dim_index=0, use_cer=False, log_prediction=True, dist_sync_on_step=False
):
super(RNNTWER, self).__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
super(RNNTWER, self).__init__(dist_sync_on_step=dist_sync_on_step)
self.decoding = decoding
self.batch_dim_index = batch_dim_index
self.use_cer = use_cer
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/rnnt_wer_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def __init__(
log_prediction: bool = True,
dist_sync_on_step=False,
):
super(RNNTBPEWER, self).__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
super(RNNTBPEWER, self).__init__(dist_sync_on_step=dist_sync_on_step)
self.decoding = decoding
self.batch_dim_index = batch_dim_index
self.use_cer = use_cer
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,7 @@ def __init__(
fold_consecutive=True,
dist_sync_on_step=False,
):
super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
super().__init__(dist_sync_on_step=dist_sync_on_step)

self.decoding = decoding
self.use_cer = use_cer
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/wer_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def __init__(
fold_consecutive=True,
dist_sync_on_step=False,
):
super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.decoding = decoding
self.tokenizer = self.decoding.tokenizer
self.blank_id = self.decoding.tokenizer.tokenizer.vocab_size
Expand Down
9 changes: 2 additions & 7 deletions nemo/collections/common/metrics/global_average_loss_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ class GlobalAverageLossMetric(Metric):
See :doc:`PyTorch Lightning Metrics<pytorch-lightning:metrics>` for the metric usage instruction.
Args:
compute_on_step:
The method :meth:`forward` only calls ``update()`` and returns ``None`` if this is set to ``False``.
default: ``True``
dist_sync_on_step:
Synchronize metric state across processes at each method :meth:`forward` call before returning the
value at the step
Expand All @@ -44,10 +41,8 @@ class GlobalAverageLossMetric(Metric):

full_state_update = True

def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, take_avg_loss=True):
super().__init__(
compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group
)
def __init__(self, dist_sync_on_step=False, process_group=None, take_avg_loss=True):
super().__init__(dist_sync_on_step=dist_sync_on_step, process_group=process_group)
self.add_state("loss_sum", torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx='sum')
self.add_state("num_measurements", torch.tensor(0, dtype=torch.int64), dist_reduce_fx='sum')
self.take_avg_loss = take_avg_loss
Expand Down
8 changes: 2 additions & 6 deletions nemo/collections/common/metrics/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class Perplexity(Metric):
See `PyTorch Lightning Metrics <https://pytorch-lightning.readthedocs.io/en/stable/ecosystem/metrics.html>`_ for the metric usage instructions.
Args:
compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True``
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
Expand All @@ -44,10 +42,8 @@ class Perplexity(Metric):

full_state_update = True

def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, validate_args=True):
super().__init__(
compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group
)
def __init__(self, dist_sync_on_step=False, process_group=None, validate_args=True):
super().__init__(dist_sync_on_step=dist_sync_on_step, process_group=process_group)
self.validate_args = validate_args
self.add_state('perplexities_sum', torch.tensor(0.0, dtype=torch.float64), dist_reduce_fx='sum')
# Total number of distributions seen since last reset
Expand Down
9 changes: 2 additions & 7 deletions nemo/collections/nlp/metrics/sequence_perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ class SequencePerplexity(Metric):
See :doc:`PyTorch Lightning Metrics<pytorch-lightning:metrics>` for the metric usage instructions.
Args:
compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True``
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()`` before returning the value at the step.
process_group:
Expand All @@ -43,12 +41,9 @@ class SequencePerplexity(Metric):
to perform the allgather.
"""

def __init__(self, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None):
def __init__(self, dist_sync_on_step=False, process_group=None, dist_sync_fn=None):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
dist_sync_on_step=dist_sync_on_step, process_group=process_group, dist_sync_fn=dist_sync_fn,
)

# Total sum of exponentiated average negative log likelihoods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# create extra bias

# setup to track metrics
self.validation_perplexity = Perplexity(compute_on_step=False)
self.validation_perplexity = Perplexity()

self.setup_optimization(cfg.optim)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
decoder=cfg.language_model.pretrained_decoder_model_name,
)

self.validation_perplexity = Perplexity(compute_on_step=False)
self.validation_perplexity = Perplexity()

self.setup_optimization(cfg.optim)

Expand Down
1 change: 0 additions & 1 deletion nemo/core/optim/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
AVAILABLE_OPTIMIZERS['fused_adam'] = FusedAdam
except ModuleNotFoundError:
HAVE_APEX = False
logging.warning("Apex was not found. Using the lamb or fused_adam optimizer will error out.")

HAVE_APEX_DISTRIBUTED_ADAM = False
if HAVE_APEX:
Expand Down
8 changes: 3 additions & 5 deletions tests/collections/common/pl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _class_test(
calculated across devices for each batch (and not just at the end)
"""
# Instanciate lightning metric
metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args)
metric = metric_class(dist_sync_on_step=dist_sync_on_step, **metric_args)

# verify metrics work after being loaded from pickled state
pickled_metric = pickle.dumps(metric)
Expand Down Expand Up @@ -303,7 +303,7 @@ def _perplexity_class_test(
calculated across devices for each batch (and not just at the end)
"""
# Instanciate lightning metric
perplexity = Perplexity(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args)
perplexity = Perplexity(dist_sync_on_step=dist_sync_on_step, **metric_args)
if (probs is None) == (logits is None):
with pytest.raises(ValueError):
perplexity(probs, logits)
Expand Down Expand Up @@ -464,9 +464,7 @@ def _loss_class_test(
calculated across devices for each batch (and not just at the end)
"""
# Instantiate lightning metric
loss_metric = GlobalAverageLossMetric(
compute_on_step=True, dist_sync_on_step=dist_sync_on_step, take_avg_loss=take_avg_loss
)
loss_metric = GlobalAverageLossMetric(dist_sync_on_step=dist_sync_on_step, take_avg_loss=take_avg_loss)

# verify loss works after being loaded from pickled state
pickled_metric = pickle.dumps(loss_metric)
Expand Down

0 comments on commit 47efc37

Please sign in to comment.