From 6f1c0d12b6ae258d05fe4b8e5e3ea9a51d9c8322 Mon Sep 17 00:00:00 2001 From: Jan Lasek Date: Mon, 25 Mar 2024 17:39:37 +0100 Subject: [PATCH] Use torch get_rank() for broader coverage of use-cases Signed-off-by: Jan Lasek --- nemo/export/quantize/quantizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/export/quantize/quantizer.py b/nemo/export/quantize/quantizer.py index d24e2a80babc2..bfdea9b10e1c1 100644 --- a/nemo/export/quantize/quantizer.py +++ b/nemo/export/quantize/quantizer.py @@ -128,7 +128,7 @@ def _load_model( self._check_ddp_initialized(model) - if is_global_rank_zero(): + if dist.get_rank() == 0: print(model) return model @@ -183,7 +183,7 @@ def quantize( def forward_loop(): for i, batch in enumerate(dataloader): - if is_global_rank_zero(): + if dist.get_rank() == 0: print(f"Calibrating batch {i}") model.predict_step(batch, i) @@ -212,7 +212,7 @@ def export(self, model, model_save: str): export_tensorrt_llm_config=self.export_config.export_tensorrt_llm_config, ) dist.barrier() # Wait until all ranks complete export_model_config step - if is_global_rank_zero(): + if dist.get_rank() == 0: logging.info(f"Exporting quantized weights, model artifacts, and tokenizer config to {model_save}...") save_artifacts(model, export_dir) if save_qnemo: