diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py index 91fa4a6f92b5..110e59494b52 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py @@ -241,7 +241,7 @@ def inference_step_validation_call(self, batch, batch_idx, data_cfg, dataloader_ } return outputs - def gather_and_maybe_write_predictions(self, output, data_cfg, mode, dataloader_idx=0): + def gather_and_maybe_write_predictions(self, output, data_cfg, mode, averaged_metric, dataloader_idx=0): if not data_cfg.get("write_embeddings_to_file", False): return True gathered_output_batches = [None for _ in range(parallel_state.get_data_parallel_world_size())]