Skip to content

Commit 48b8204

Browse files
Add support to perform "inference-only" without loading training data (NVIDIA#8640)
* Add support to perform "inference-only" without loading training data

  Currently, the MegatronSBERT model cannot be used for inference on its own: a user cannot simply load a trained .nemo checkpoint and call forward() on it, because __init__ always tries to read the training data file. This patch wraps the training-data loading in a try/except block so that the model can still be constructed when training data is not specified or cannot be loaded.

  Signed-off-by: Aditya Malte <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

---------

Signed-off-by: Aditya Malte <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent fba71c0 commit 48b8204

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

nemo/collections/nlp/models/information_retrieval/megatron_sbert_model.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -391,15 +391,23 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
391391
self.cross_entropy_loss = torch.nn.CrossEntropyLoss(label_smoothing=cfg.get('label_smoothing', 0.0))
392392
softmax_temp = cfg.get('softmax_temp', 0.05)
393393
self.scale = 1.0 / softmax_temp
394-
train_file_path = self.cfg.data.data_prefix
395-
with open(train_file_path) as f:
396-
train_data = json.load(f)
397-
398-
random_seed = 42
399-
set_seed(random_seed)
400-
random.shuffle(train_data)
401-
402-
self.train_data = train_data
394+
try:
395+
train_file_path = self.cfg.data.data_prefix
396+
with open(train_file_path) as f:
397+
train_data = json.load(f)
398+
399+
random_seed = 42
400+
set_seed(random_seed)
401+
random.shuffle(train_data)
402+
403+
self.train_data = train_data
404+
logging.warning("Model is running in training mode")
405+
except:
406+
logging.warning(
407+
"Model is running inference mode as training data is not specified, or could not be loaded"
408+
)
409+
random_seed = 42
410+
set_seed(random_seed)
403411

404412
def model_provider_func(self, pre_process, post_process):
405413
cfg = self.cfg

0 commit comments

Comments
 (0)