Skip to content

Commit

Permalink
Test only if the model was trained on single GPU for accurate results. (
Browse files Browse the repository at this point in the history
#3470)

* Test only if the model was trained on single GPU for accurate results.

Signed-off-by: smajumdar <[email protected]>

* Test only if the model was trained on single GPU for accurate results.

Signed-off-by: smajumdar <[email protected]>
  • Loading branch information
titu1994 authored and ericharper committed Jan 24, 2022
1 parent e998de4 commit 1d6ccb0
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 38 deletions.
11 changes: 2 additions & 9 deletions examples/asr/asr_ctc/speech_to_text_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,8 @@ def main(cfg):
trainer.fit(asr_model)

if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
gpu = 1 if cfg.trainer.gpus != 0 else 0
test_trainer = pl.Trainer(
gpus=gpu,
precision=trainer.precision,
amp_level=trainer.accelerator_connector.amp_level,
amp_backend=cfg.trainer.get("amp_backend", "native"),
)
if asr_model.prepare_test(test_trainer):
test_trainer.test(asr_model)
if asr_model.prepare_test(trainer):
trainer.test(asr_model)


if __name__ == '__main__':
Expand Down
11 changes: 2 additions & 9 deletions examples/asr/asr_ctc/speech_to_text_ctc_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,8 @@ def main(cfg):
trainer.fit(asr_model)

if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
gpu = 1 if cfg.trainer.gpus != 0 else 0
test_trainer = pl.Trainer(
gpus=gpu,
precision=trainer.precision,
amp_level=trainer.accelerator_connector.amp_level,
amp_backend=cfg.trainer.get("amp_backend", "native"),
)
if asr_model.prepare_test(test_trainer):
test_trainer.test(asr_model)
if asr_model.prepare_test(trainer):
trainer.test(asr_model)


if __name__ == '__main__':
Expand Down
11 changes: 2 additions & 9 deletions examples/asr/asr_transducer/speech_to_text_rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,8 @@ def main(cfg):
trainer.fit(asr_model)

if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
gpu = 1 if cfg.trainer.gpus != 0 else 0
test_trainer = pl.Trainer(
gpus=gpu,
precision=trainer.precision,
amp_level=trainer.accelerator_connector.amp_level,
amp_backend=cfg.trainer.get("amp_backend", "native"),
)
if asr_model.prepare_test(test_trainer):
test_trainer.test(asr_model)
if asr_model.prepare_test(trainer):
trainer.test(asr_model)


if __name__ == '__main__':
Expand Down
11 changes: 2 additions & 9 deletions examples/asr/asr_transducer/speech_to_text_rnnt_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,8 @@ def main(cfg):
trainer.fit(asr_model)

if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
gpu = 1 if cfg.trainer.gpus != 0 else 0
test_trainer = pl.Trainer(
gpus=gpu,
precision=trainer.precision,
amp_level=trainer.accelerator_connector.amp_level,
amp_backend=cfg.trainer.get("amp_backend", "native"),
)
if asr_model.prepare_test(test_trainer):
test_trainer.test(asr_model)
if asr_model.prepare_test(trainer):
trainer.test(asr_model)


if __name__ == '__main__':
Expand Down
2 changes: 0 additions & 2 deletions examples/asr/speech_classification/speech_to_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ def main(cfg):
trainer.fit(asr_model)

if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
gpu = 1 if cfg.trainer.gpus != 0 else 0
trainer = pl.Trainer(gpus=gpu)
if asr_model.prepare_test(trainer):
trainer.test(asr_model)

Expand Down

0 comments on commit 1d6ccb0

Please sign in to comment.