Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions nemo_deploy/llm/inference/inference_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,20 @@ def setup_megatron_model_and_tokenizer_for_inference(
dist_config = DistributedInitConfig(distributed_backend="nccl")
torch_distributed_init(dist_config)
model_config, mlm_args = load_model_config(checkpoint_path)

# Convert attention_backend from string to enum if needed
if hasattr(model_config, "attention_backend") and isinstance(model_config.attention_backend, str):
if model_config.attention_backend == "AttnBackend.fused":
model_config.attention_backend = AttnBackend.fused
elif model_config.attention_backend == "AttnBackend.flash":
model_config.attention_backend = AttnBackend.flash
elif model_config.attention_backend == "AttnBackend.unfused":
model_config.attention_backend = AttnBackend.unfused
elif model_config.attention_backend == "AttnBackend.local":
model_config.attention_backend = AttnBackend.local
elif model_config.attention_backend == "AttnBackend.auto":
model_config.attention_backend = AttnBackend.auto

if tensor_model_parallel_size is not None:
model_config.tensor_model_parallel_size = tensor_model_parallel_size
if pipeline_model_parallel_size is not None:
Expand Down
2 changes: 1 addition & 1 deletion nemo_deploy/llm/inference/tron_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def _initialize_tp_communicators(model_config: Union[GPTConfig, T5Config], micro
"Tensor Parallel Communication/GEMM Overlap optimization needs 'yaml' and 'transformer_engine' packages"
)

if model_config.tp_comm_overlap_cfg is not None:
if hasattr(model_config, "tp_comm_overlap_cfg") and model_config.tp_comm_overlap_cfg is not None:
with open(model_config.tp_comm_overlap_cfg, "r") as stream:
ub_cfgs = yaml.safe_load(stream)
else:
Expand Down
Loading