Skip to content

Commit

Permalink
Use torch get_rank() for broader coverage of use-cases
Browse files Browse the repository at this point in the history
Signed-off-by: Jan Lasek <[email protected]>
  • Loading branch information
janekl committed Mar 25, 2024
1 parent fad7b57 commit 6f1c0d1
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions nemo/export/quantize/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def _load_model(

self._check_ddp_initialized(model)

if is_global_rank_zero():
if dist.get_rank() == 0:
print(model)

return model
Expand Down Expand Up @@ -183,7 +183,7 @@ def quantize(

def forward_loop():
for i, batch in enumerate(dataloader):
if is_global_rank_zero():
if dist.get_rank() == 0:
print(f"Calibrating batch {i}")
model.predict_step(batch, i)

Expand Down Expand Up @@ -212,7 +212,7 @@ def export(self, model, model_save: str):
export_tensorrt_llm_config=self.export_config.export_tensorrt_llm_config,
)
dist.barrier() # Wait until all ranks complete export_model_config step
if is_global_rank_zero():
if dist.get_rank() == 0:
logging.info(f"Exporting quantized weights, model artifacts, and tokenizer config to {model_save}...")
save_artifacts(model, export_dir)
if save_qnemo:
Expand Down

0 comments on commit 6f1c0d1

Please sign in to comment.