From b440b0a910b27681567cca8ec3b56a2bad08b403 Mon Sep 17 00:00:00 2001
From: Nathan Azrak <42650258+nathan-az@users.noreply.github.com>
Date: Fri, 5 Sep 2025 07:45:29 +1000
Subject: [PATCH] Explicitly pass `expert_tensor_parallel_size` to `initialize_model_parallel` (#537)

* Pass expert_tensor_parallel_size to pstate init

Signed-off-by: Nathan Azrak <42650258+nathan-az@users.noreply.github.com>

* check correct key

Signed-off-by: Nathan Azrak <42650258+nathan-az@users.noreply.github.com>

---------

Signed-off-by: Nathan Azrak <42650258+nathan-az@users.noreply.github.com>
---
 src/megatron/bridge/models/model_provider.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/megatron/bridge/models/model_provider.py b/src/megatron/bridge/models/model_provider.py
index 8c3f111731..d53e03773f 100644
--- a/src/megatron/bridge/models/model_provider.py
+++ b/src/megatron/bridge/models/model_provider.py
@@ -215,6 +215,7 @@ def initialize_model_parallel(
             virtual_pipeline_model_parallel_size=getattr(self, "virtual_pipeline_model_parallel_size", None),
             context_parallel_size=getattr(self, "context_parallel_size", 1) or 1,
             expert_model_parallel_size=getattr(self, "expert_model_parallel_size", 1) or 1,
+            expert_tensor_parallel_size=getattr(self, "expert_tensor_parallel_size", None),
             **model_parallel_kwargs,
         )
         if seed is not None: