diff --git a/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb b/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb index f94f102abf8e..5580bc4cf946 100644 --- a/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb +++ b/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb @@ -936,20 +936,22 @@ "outputs": [], "source": [ "# let's reload our pretrained model\n", - "pretrained_model = nemo_nlp.models.PunctuationCapitalizationLexicalAudioModel.from_pretrained('Punctuation_And_Capitalization_Lexical_Audio/checkpoints/Punctuation_and_Capitalization_Lexical_Audio.nemo')\n", + "pretrained_model = nemo_nlp.models.PunctuationCapitalizationLexicalAudioModel.restore_from('Punctuation_And_Capitalization_Lexical_Audio/checkpoints/Punctuation_and_Capitalization_Lexical_Audio.nemo')\n", "\n", "# setup train and validation Pytorch DataLoaders\n", "pretrained_model.update_config_after_restoring_from_checkpoint(\n", " train_ds={\n", " 'ds_item': DATA_DIR,\n", - " 'text_file': 'text_train.txt',\n", - " 'labels_file': 'labels_train.txt',\n", + " 'text_file': 'text_dev.txt',\n", + " 'labels_file': 'labels_dev.txt',\n", + " 'audio_file': 'audio_dev.txt',\n", " 'tokens_in_batch': 1024,\n", " },\n", " validation_ds={\n", " 'ds_item': DATA_DIR,\n", " 'text_file': 'text_dev.txt',\n", " 'labels_file': 'labels_dev.txt',\n", + " 'audio_file': 'audio_dev.txt',\n", " 'tokens_in_batch': 1024,\n", " },\n", ")\n", @@ -960,8 +962,8 @@ "fast_dev_run = True\n", "trainer = pl.Trainer(devices=1, accelerator='gpu', fast_dev_run=fast_dev_run)\n", "pretrained_model.set_trainer(trainer)\n", - "pretrained_model.setup_training_data()\n", - "pretrained_model.setup_validation_data()\n", + "pretrained_model.setup_training_data(pretrained_model.cfg.train_ds)\n", + "pretrained_model.setup_validation_data(pretrained_model.cfg.validation_ds)\n", "trainer.fit(pretrained_model)" ] }, @@ -997,7 +999,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.8.13" } }, "nbformat": 4,