From d86651d0d03fe089f2c156bd40a28b4fca5f21dc Mon Sep 17 00:00:00 2001 From: Nathan Azrak <42650258+nathan-az@users.noreply.github.com> Date: Thu, 4 Sep 2025 00:35:47 +1000 Subject: [PATCH 1/2] Pass expert_tensor_parallel_size to pstate init Signed-off-by: Nathan Azrak <42650258+nathan-az@users.noreply.github.com> --- src/megatron/bridge/models/model_provider.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/megatron/bridge/models/model_provider.py b/src/megatron/bridge/models/model_provider.py index 8c3f111731..d59215a58c 100644 --- a/src/megatron/bridge/models/model_provider.py +++ b/src/megatron/bridge/models/model_provider.py @@ -215,6 +215,7 @@ def initialize_model_parallel( virtual_pipeline_model_parallel_size=getattr(self, "virtual_pipeline_model_parallel_size", None), context_parallel_size=getattr(self, "context_parallel_size", 1) or 1, expert_model_parallel_size=getattr(self, "expert_model_parallel_size", 1) or 1, + expert_tensor_parallel_size=getattr(self, "tensor_model_parallel_size", None), **model_parallel_kwargs, ) if seed is not None: From 91fba7d1799f51c9dc51b1ad61a4d5a5fd2f8725 Mon Sep 17 00:00:00 2001 From: Nathan Azrak <42650258+nathan-az@users.noreply.github.com> Date: Thu, 4 Sep 2025 00:40:57 +1000 Subject: [PATCH 2/2] check correct key Signed-off-by: Nathan Azrak <42650258+nathan-az@users.noreply.github.com> --- src/megatron/bridge/models/model_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/megatron/bridge/models/model_provider.py b/src/megatron/bridge/models/model_provider.py index d59215a58c..d53e03773f 100644 --- a/src/megatron/bridge/models/model_provider.py +++ b/src/megatron/bridge/models/model_provider.py @@ -215,7 +215,7 @@ def initialize_model_parallel( virtual_pipeline_model_parallel_size=getattr(self, "virtual_pipeline_model_parallel_size", None), context_parallel_size=getattr(self, "context_parallel_size", 1) or 1, expert_model_parallel_size=getattr(self, "expert_model_parallel_size", 1) or 1, - expert_tensor_parallel_size=getattr(self, "tensor_model_parallel_size", None), + expert_tensor_parallel_size=getattr(self, "expert_tensor_parallel_size", None), **model_parallel_kwargs, ) if seed is not None: