
[ASR] Add pretrained ASR models for Croatian
Signed-off-by: Ante Jukić <[email protected]>
anteju committed Aug 5, 2022
1 parent 816ffda commit 452f18a
Showing 6 changed files with 39 additions and 7 deletions.
3 changes: 3 additions & 0 deletions docs/source/asr/data/benchmark_hr.csv
@@ -0,0 +1,3 @@
Model,Model Base Class,Model Card
stt_hr_conformer_ctc_large,EncDecCTCModelBPE,"https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_hr_conformer_ctc_large"
stt_hr_conformer_transducer_large,EncDecRNNTBPEModel,"https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_hr_conformer_transducer_large"
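
The new names can be loaded through NeMo's from_pretrained() API once the checkpoints are published on NGC. A minimal usage sketch (not part of this commit; the audio path is a hypothetical placeholder):

import nemo.collections.asr as nemo_asr

# CTC model with BPE tokenizer
ctc_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name="stt_hr_conformer_ctc_large")

# Transducer (RNN-T) model with BPE tokenizer
rnnt_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="stt_hr_conformer_transducer_large")

# Transcribe a list of 16 kHz mono WAV files
print(ctc_model.transcribe(paths2audio_files=["sample_hr.wav"]))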
3 changes: 3 additions & 0 deletions docs/source/asr/data/scores/hr/conformer_hr.csv
@@ -0,0 +1,3 @@
Model Name,Language,ParlaSpeech-HR v1.0 (dev),ParlaSpeech-HR v1.0 (test)
stt_hr_conformer_ctc_large,hr,4.43,4.70
stt_hr_conformer_transducer_large,hr,4.56,4.69
10 changes: 10 additions & 0 deletions docs/source/asr/scores.rst
@@ -169,6 +169,16 @@ FR

--------------------

HR
^^

.. csv-table::
:header-rows: 1
:align: left
:file: data/scores/hr/conformer_hr.csv

--------------------

IT
^^

6 changes: 6 additions & 0 deletions nemo/collections/asr/models/ctc_bpe_models.py
@@ -532,6 +532,12 @@ def list_available_models(cls) -> Optional[PretrainedModelInfo]:
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_rw_conformer_ctc_large",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_rw_conformer_ctc_large/versions/1.11.0/files/stt_rw_conformer_ctc_large.nemo",
)

model = PretrainedModelInfo(
pretrained_model_name="stt_hr_conformer_ctc_large",
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_hr_conformer_ctc_large",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_hr_conformer_ctc_large/versions/1.11.0/files/stt_hr_conformer_ctc_large.nemo",
)
results.append(model)

return results
6 changes: 6 additions & 0 deletions nemo/collections/asr/models/rnnt_bpe_models.py
@@ -195,6 +195,12 @@ def list_available_models(cls) -> List[PretrainedModelInfo]:
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_rw_conformer_transducer_large",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_rw_conformer_transducer_large/versions/1.11.0/files/stt_rw_conformer_transducer_large.nemo",
)

model = PretrainedModelInfo(
pretrained_model_name="stt_hr_conformer_transducer_large",
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_hr_conformer_transducer_large",
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_hr_conformer_transducer_large/versions/1.11.0/files/stt_hr_conformer_transducer_large.nemo",
)
results.append(model)

return results
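
Both registrations surface through the classes' list_available_models() hooks, which from_pretrained() consults to resolve a model name to its NGC location. A rough sketch of how the new entries can be inspected, assuming the NeMo API as of this release:

from nemo.collections.asr.models import EncDecCTCModelBPE

for info in EncDecCTCModelBPE.list_available_models():
    if info.pretrained_model_name == "stt_hr_conformer_ctc_large":
        print(info.location)  # URL of the .nemo checkpoint on NGC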
18 changes: 11 additions & 7 deletions scripts/checkpoint_averaging/checkpoint_averaging.py
@@ -38,19 +38,21 @@

from nemo.core import ModelPT
from nemo.utils import logging, model_utils
from tqdm.auto import tqdm


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'model_fname_list',
metavar='N',
metavar='NEMO_FILE_OR_FOLDER',
type=str,
nargs='+',
help='Input .nemo files (or folders that contain them) to parse',
)
parser.add_argument(
'--import_fname_list',
metavar='FILE',
type=str,
nargs='+',
default=[],
@@ -59,7 +61,7 @@ def main():
args = parser.parse_args()

logging.info(
f"\n\nIMPORTANT: Use --import_fname_list for all files that contain missing classes (AttributeError: Can't get attribute '???' on <module '__main__' from '???'>)\n\n"
f"\n\nIMPORTANT:\nIf you get the following error:\n\t(AttributeError: Can't get attribute '???' on <module '__main__' from '???'>)\nuse:\n\t--import_fname_list\nfor all files that contain missing classes.\n\n"
)

for fn in args.import_fname_list:
@@ -77,7 +79,7 @@ def main():
filter(lambda fn: not fn.endswith("-averaged.nemo"), glob.glob(os.path.join(model_fname, "*.nemo")))
)
if len(nemo_files) != 1:
raise RuntimeError(f"Expected only a single .nemo files but discovered {len(nemo_files)} .nemo files")
raise RuntimeError(f"Expected exactly one .nemo file but discovered {len(nemo_files)} .nemo files")

model_fname = nemo_files[0]

@@ -107,23 +109,25 @@

logging.info(f"Averaging {n} checkpoints ...")

for ix, path in enumerate(checkpoint_paths):
for ix, path in enumerate(tqdm(checkpoint_paths, total=n, desc='Averaging checkpoints')):
checkpoint = torch.load(path, map_location=device)

if 'state_dict' in checkpoint:
checkpoint = checkpoint['state_dict']
else:
raise RuntimeError(f"Checkpoint from {path} does not include a state_dict.")

if ix == 0:
# Initial state
avg_state = checkpoint

logging.info(f"Initialized average state dict with checkpoint : {path}")
logging.info(f"Initialized average state dict with checkpoint:\n\t{path}")
else:
# Accumulated state
for k in avg_state:
avg_state[k] = avg_state[k] + checkpoint[k]

logging.info(f"Updated average state dict with state from checkpoint : {path}")
logging.info(f"Updated average state dict with state from checkpoint:\n\t{path}")

for k in avg_state:
if str(avg_state[k].dtype).startswith("torch.int"):
@@ -136,7 +140,7 @@ def main():
# restore merged weights into model
nemo_model.load_state_dict(avg_state, strict=True)
# Save model
logging.info(f"Saving average mdel to: {avg_model_fname}")
logging.info(f"Saving average model to:\n\t{avg_model_fname}")
nemo_model.save_to(avg_model_fname)
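
For reference, a condensed, self-contained sketch of the averaging logic this script appears to implement: floating-point entries of each state_dict are summed and divided by the number of checkpoints, while integer buffers (e.g. BatchNorm num_batches_tracked) are accumulated but not divided. The helper name below is illustrative, and the sketch assumes all checkpoints share the same keys and shapes:

import torch
from tqdm.auto import tqdm

def average_checkpoints(checkpoint_paths, device='cpu'):
    n = len(checkpoint_paths)
    avg_state = None
    for path in tqdm(checkpoint_paths, total=n, desc='Averaging checkpoints'):
        checkpoint = torch.load(path, map_location=device)
        if 'state_dict' not in checkpoint:
            raise RuntimeError(f"Checkpoint from {path} does not include a state_dict.")
        state = checkpoint['state_dict']
        if avg_state is None:
            avg_state = state  # initial state
        else:
            for k in avg_state:
                avg_state[k] = avg_state[k] + state[k]  # accumulate
    for k in avg_state:
        # Integer buffers are left as accumulated values, not averaged
        if not str(avg_state[k].dtype).startswith('torch.int'):
            avg_state[k] = avg_state[k] / n
    return avg_state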


