diff --git a/src/megatron/bridge/models/model_provider.py b/src/megatron/bridge/models/model_provider.py index 8c3f111731..d53e03773f 100644 --- a/src/megatron/bridge/models/model_provider.py +++ b/src/megatron/bridge/models/model_provider.py @@ -215,6 +215,7 @@ def initialize_model_parallel( virtual_pipeline_model_parallel_size=getattr(self, "virtual_pipeline_model_parallel_size", None), context_parallel_size=getattr(self, "context_parallel_size", 1) or 1, expert_model_parallel_size=getattr(self, "expert_model_parallel_size", 1) or 1, + expert_tensor_parallel_size=getattr(self, "expert_tensor_parallel_size", None), **model_parallel_kwargs, ) if seed is not None: