Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions examples/speechlm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,9 @@ python speech_to_text_llm_train.py \
++data.validation_ds.seed=10 \
++data.validation_ds.shard_seed="randomized" \
++data.validation_ds.shuffle=false \
data.validation_ds.metric.name='loss' \ # set to `loss` to only calculate validation loss w/o LLM decoding for faster validation
++data.validation_ds.force_iterable_dataset=true \ # set to true for mixing tarred and non-tarred data
++data.validation_ds.metric.name='loss' \ # set to `loss` to only calculate validation loss w/o LLM decoding for faster validation
++model.data.validation_ds.force_finite=true \
++model.data.validation_ds.force_map_dataset=true \
++trainer.use_distributed_sampler=false \
++trainer.limit_train_batches=2000 \
trainer.val_check_interval=2000 \ # set to same value as limit_train_batches
Expand Down Expand Up @@ -178,12 +179,21 @@ python speech_to_text_llm_validate.py \
++data.validation_ds.quadratic_duration=null \
++data.validation_ds.bucket_duration_bins=null \
++data.validation_ds.shuffle=false \
++model.data.validation_ds.force_finite=true \
++model.data.validation_ds.force_map_dataset=true \
++trainer.use_distributed_sampler=false \
++resume.resume_from_path=$CKPT_PATH \ # path to the checkpoint to load
++data.validation_ds.write_predictions_to_file=true \
++data.validation_ds.output_dir=$OUTPUT_DIR \ # directory to save the predictions
name="${CONFIG_NAME}_run1_eval" \
trainer.devices=1 \
data.common.tokens_to_generate=256 \
++model.inference_config.tokens_to_generate=256 \
++model.inference_config.temperature=1.0 \
++model.inference_config.top_k=50 \
++model.inference_config.top_p=0.95 \
++model.inference_config.greedy=false \ # set to `true` to use greedy decoding instead of sampling
++model.inference_config.repetition_penalty=1.0 \
~logger.wandb # remove wandb logger
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __getitem__(self, all_cuts: CutSet) -> dict[str, Union[torch.Tensor, list[st
def _get_metadata(self, all_cuts: CutSet) -> List[dict]:
metadata = []
for cut in all_cuts:
metadata.append({"type": type(cut).__name__, "id": getattr(cut, "id", "n/a")})
metadata.append({"type": type(cut).__name__, "id": getattr(cut, "id", "n/a"), "cut": str(cut)})
return metadata

def _process_sample(self, sample: Any) -> dict:
Expand Down
22 changes: 21 additions & 1 deletion nemo/collections/speechlm/models/speech_to_text_llm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,27 @@ def _reconfigure_and_process_inference_batch(self, batch, data_cfg):
)

def set_inference_config(self, inference_config: Optional[Dict] = None):
    """Validate and store the inference (decoding) configuration.

    Only the sampling-related keys listed in ``ALLOWED_KEYS`` are kept;
    unknown keys are dropped with a warning. Passing ``None`` clears any
    previously stored config (mirroring ``get_inference_config``, which
    treats ``None`` as "no config set").

    Args:
        inference_config: mapping of decoding options (a plain dict or any
            mapping convertible via ``dict()``, e.g. an OmegaConf DictConfig),
            or ``None`` to clear the stored config.
    """
    ALLOWED_KEYS = [
        'tokens_to_generate',
        'temperature',
        'top_k',
        'top_p',
        'greedy',
        'repetition_penalty',
        'min_tokens_to_generate',
    ]
    if inference_config is None:
        # Explicitly clear so a stale config from an earlier call cannot leak.
        self._inference_config = None
        return
    # Always copy: filtering below must never mutate the caller's mapping
    # (the previous `isinstance` check skipped the copy for plain dicts).
    config = dict(inference_config)
    # Iterate over a snapshot of the keys — popping from a dict while
    # iterating its live view raises RuntimeError.
    for key in list(config):
        if key not in ALLOWED_KEYS:
            logging.warning(
                f"inference_config key `{key}` is not in allowed keys ({ALLOWED_KEYS}), ignoring it..."
            )
            config.pop(key)
    self._inference_config = config
    logging.info(f"Setting inference config: {self._inference_config}")

def get_inference_config(self):
    """Return a shallow copy of the stored inference config, or None if unset."""
    stored = self._inference_config
    if stored is None:
        return None
    return dict(stored)
Expand Down
1 change: 1 addition & 0 deletions nemo/collections/speechlm/recipes/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def build_components(cfg: DictConfig, tokenizer: Optional[AutoTokenizer] = None)
data_config=cfg['data']['common'],
resume_speech_model_from_path=cfg['model'].get('resume_speech_model_from_path', None),
resume_modality_adapter_from_path=cfg['model'].get('resume_modality_adapter_from_path', None),
inference_config=cfg['model'].get('inference_config', None),
)

if model_config.language_model_from_pretrained:
Expand Down
Loading