Skip to content

Commit

Permalink
tutorial fixes (NVIDIA#5354) (NVIDIA#5361)
Browse files Browse the repository at this point in the history
Signed-off-by: Matvei Novikov <[email protected]>

Signed-off-by: Matvei Novikov <[email protected]>

Signed-off-by: Matvei Novikov <[email protected]>
Co-authored-by: Matvei Novikov <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
  • Loading branch information
2 people authored and Hainan Xu committed Nov 29, 2022
1 parent 4251963 commit 48d389f
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -936,20 +936,22 @@
"outputs": [],
"source": [
"# let's reload our pretrained model\n",
"pretrained_model = nemo_nlp.models.PunctuationCapitalizationLexicalAudioModel.from_pretrained('Punctuation_And_Capitalization_Lexical_Audio/checkpoints/Punctuation_and_Capitalization_Lexical_Audio.nemo')\n",
"pretrained_model = nemo_nlp.models.PunctuationCapitalizationLexicalAudioModel.restore_from('Punctuation_And_Capitalization_Lexical_Audio/checkpoints/Punctuation_and_Capitalization_Lexical_Audio.nemo')\n",
"\n",
"# setup train and validation Pytorch DataLoaders\n",
"pretrained_model.update_config_after_restoring_from_checkpoint(\n",
" train_ds={\n",
" 'ds_item': DATA_DIR,\n",
" 'text_file': 'text_train.txt',\n",
" 'labels_file': 'labels_train.txt',\n",
" 'text_file': 'text_dev.txt',\n",
" 'labels_file': 'labels_dev.txt',\n",
" 'audio_file': 'audio_dev.txt',\n",
" 'tokens_in_batch': 1024,\n",
" },\n",
" validation_ds={\n",
" 'ds_item': DATA_DIR,\n",
" 'text_file': 'text_dev.txt',\n",
" 'labels_file': 'labels_dev.txt',\n",
" 'audio_file': 'audio_dev.txt',\n",
" 'tokens_in_batch': 1024,\n",
" },\n",
")\n",
Expand All @@ -960,8 +962,8 @@
"fast_dev_run = True\n",
"trainer = pl.Trainer(devices=1, accelerator='gpu', fast_dev_run=fast_dev_run)\n",
"pretrained_model.set_trainer(trainer)\n",
"pretrained_model.setup_training_data()\n",
"pretrained_model.setup_validation_data()\n",
"pretrained_model.setup_training_data(pretrained_model.cfg.train_ds)\n",
"pretrained_model.setup_validation_data(pretrained_model.cfg.validation_ds)\n",
"trainer.fit(pretrained_model)"
]
},
Expand Down Expand Up @@ -997,7 +999,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.8.13"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 48d389f

Please sign in to comment.