diff --git a/src/megatron/bridge/models/conversion/peft_bridge.py b/src/megatron/bridge/models/conversion/peft_bridge.py
index fa69313653..c61c0b47c3 100644
--- a/src/megatron/bridge/models/conversion/peft_bridge.py
+++ b/src/megatron/bridge/models/conversion/peft_bridge.py
@@ -15,6 +15,7 @@
 from __future__ import annotations

 import itertools
+import re
 from collections import defaultdict
 from dataclasses import dataclass
 from string import digits
@@ -259,6 +260,14 @@ def _infer_qkv_projection_from_name(self, hf_name: str) -> Optional[str]:
             return "v_proj"
         return None

+    def _infer_hf_expert_idx(self, hf_name: str) -> Optional[int]:
+        """Return the expert index embedded in an HF MoE weight name."""
+
+        match = re.search(r"\bexperts\.(\d+)\b", hf_name)
+        if match is None:
+            return None
+        return int(match.group(1))
+
     def _split_qkv_linear_out_weight(
         self,
         megatron_model: Union[MegatronModel, List[MegatronModel]],
@@ -270,6 +279,70 @@ def _split_qkv_linear_out_weight(
         q_out, k_out, v_out = split_qkv_weights(model.config, linear_out_weight)
         return {"q_proj": q_out, "k_proj": k_out, "v_proj": v_out}

+    def _split_fused_fc1_linear_out_weight(
+        self,
+        linear_out_weight: torch.Tensor,
+        *,
+        is_expert: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Split fused FC1 LoRA linear_out into gate/up with TP-aware ordering."""
+
+        tp_size = (
+            parallel_state.get_expert_tensor_parallel_world_size()
+            if is_expert
+            else parallel_state.get_tensor_model_parallel_world_size()
+        )
+        if tp_size <= 1:
+            return torch.chunk(linear_out_weight, 2, dim=0)
+
+        shard_size = linear_out_weight.shape[0] // tp_size
+        if shard_size * tp_size != linear_out_weight.shape[0] or shard_size % 2 != 0:
+            return torch.chunk(linear_out_weight, 2, dim=0)
+
+        shards = torch.split(linear_out_weight, shard_size, dim=0)
+        gate_parts = []
+        up_parts = []
+        for shard in shards:
+            gate_shard, up_shard = torch.chunk(shard, 2, dim=0)
+            gate_parts.append(gate_shard)
+            up_parts.append(up_shard)
+        gate = torch.cat(gate_parts, dim=0)
+        up = torch.cat(up_parts, dim=0)
+        return gate, up
+
+    def _gather_expert_adapter_weight(
+        self,
+        weight: torch.Tensor,
+    ) -> Optional[List[torch.Tensor]]:
+        """Gather expert-sharded adapter weights across EP ranks when needed."""
+        ep_size = parallel_state.get_expert_model_parallel_world_size()
+        if ep_size <= 1:
+            return None
+        assert weight.ndim < 3
+
+        gathered = [torch.empty_like(weight) for _ in range(ep_size)]
+        torch.distributed.all_gather(gathered, weight, group=parallel_state.get_expert_model_parallel_group())
+        return gathered
+
+    def _select_expert_adapter_weight(
+        self,
+        weight: torch.Tensor,
+        gathered: List[torch.Tensor],
+        expert_idx: int,
+        num_experts: int,
+    ) -> torch.Tensor:
+        """Select the per-expert adapter weight slice if present."""
+
+        assert weight.ndim < 3
+
+        ep_size = parallel_state.get_expert_model_parallel_world_size()
+        if ep_size <= 1:
+            return weight
+
+        num_experts_per_rank = num_experts // ep_size
+        rank = expert_idx // num_experts_per_rank
+        return gathered[rank]
+
     def _megatron_global_adapters_info_all_pp_ranks(
         self, megatron_model: Union[MegatronModel, List[MegatronModel]]
     ) -> List[tuple[str, str, bool, bool, int, int, int, int]]:
@@ -379,7 +452,7 @@ def build_adapter_conversion_tasks(
         local_linear_in_name, local_linear_out_name = global_linear_in_name, global_linear_out_name

         base_suffix = ".weight"
-        if is_expert_linear(global_base_prefix):
+        if is_expert_linear(global_base_prefix) and ".local_experts." not in global_base_prefix:
             # To get expert layer hf mapping properly
             base_suffix = ".weight0"

@@ -514,15 +587,44 @@ def stream_adapter_weights_megatron_to_hf(
             linear_in_tensor = adapter_weight.linear_in_weight.weight
             linear_out_tensor = adapter_weight.linear_out_weight.weight

-            if cpu:
-                linear_in_tensor = linear_in_tensor.cpu()
-                linear_out_tensor = linear_out_tensor.cpu()
+            is_expert = is_expert_linear(adapter_task.global_base_prefix)
+            is_grouped_expert = is_expert and ".local_experts." not in adapter_task.global_base_prefix
+            expert_linear_in_gathered = None
+            expert_linear_out_gathered = None
+            if is_grouped_expert:
+                expert_linear_in_gathered = self._gather_expert_adapter_weight(
+                    linear_in_tensor,
+                )
+                expert_linear_out_gathered = self._gather_expert_adapter_weight(
+                    linear_out_tensor,
+                )

             base_suffixes = [".weight"]
-            if is_expert_linear(adapter_task.global_base_prefix):
+            if is_grouped_expert:
                 base_suffixes = [f".weight{expert_num}" for expert_num in range(num_moe_experts)]

             for base_suffix in base_suffixes:
+                current_linear_in_tensor = linear_in_tensor
+                current_linear_out_tensor = linear_out_tensor
+                if is_grouped_expert:
+                    expert_idx = int(base_suffix[len(".weight") :])
+                    current_linear_in_tensor = self._select_expert_adapter_weight(
+                        linear_in_tensor,
+                        expert_linear_in_gathered,
+                        expert_idx,
+                        num_moe_experts,
+                    )
+                    current_linear_out_tensor = self._select_expert_adapter_weight(
+                        linear_out_tensor,
+                        expert_linear_out_gathered,
+                        expert_idx,
+                        num_moe_experts,
+                    )
+
+                if cpu:
+                    current_linear_in_tensor = current_linear_in_tensor.cpu()
+                    current_linear_out_tensor = current_linear_out_tensor.cpu()
+
                 base_hf_weight_names = self._get_base_hf_weight_names_for_adapter(
                     mapping_registry,
                     adapter_task.global_base_prefix,
@@ -541,25 +643,27 @@ def stream_adapter_weights_megatron_to_hf(
                 per_base_linear_out = self._get_fused_adapter_linear_out_slices(
                     megatron_model,
                     base_hf_weight_names,
-                    linear_out_tensor,
+                    current_linear_out_tensor,
+                    is_expert=is_expert_linear(adapter_task.global_base_prefix),
                 )
                 if per_base_linear_out is not None:
                     for index, base_name in enumerate(base_hf_weight_names):
-                        current_linear_out = per_base_linear_out.get(base_name)
-                        assert current_linear_out is not None, "unknown projection name"
+                        current_linear_out_tensor = per_base_linear_out.get(base_name)
+                        assert current_linear_out_tensor is not None, "unknown projection name"

-                        yield HFWeightTuple(linear_in_hf_names[index], linear_in_tensor)
-                        yield HFWeightTuple(linear_out_hf_names[index], current_linear_out)
+                        yield HFWeightTuple(linear_in_hf_names[index], current_linear_in_tensor)
+                        yield HFWeightTuple(linear_out_hf_names[index], current_linear_out_tensor)
                     continue

-                yield HFWeightTuple(linear_in_hf_names[0], linear_in_tensor)
-                yield HFWeightTuple(linear_out_hf_names[0], linear_out_tensor)
+                yield HFWeightTuple(linear_in_hf_names[0], current_linear_in_tensor)
+                yield HFWeightTuple(linear_out_hf_names[0], current_linear_out_tensor)

     def _get_fused_adapter_linear_out_slices(
         self,
         megatron_model: List[MegatronModel],
         base_hf_weight_names: List[str],
         linear_out_tensor: torch.Tensor,
+        is_expert: bool = False,
     ) -> Optional[Dict[str, torch.Tensor]]:
         """Return per-base-name linear_out slices for fused adapters, else None.
@@ -580,13 +684,16 @@ def _get_fused_adapter_linear_out_slices(

         is_fused_fc1 = self._is_fused_fc1_gate_up(base_hf_weight_names, linear_out_tensor)
         if is_fused_fc1:
-            split_size = linear_out_tensor.shape[0] // 2
+            gate_weight, up_weight = self._split_fused_fc1_linear_out_weight(
+                linear_out_tensor,
+                is_expert=is_expert,
+            )
             per_base = {}
             for base_name in base_hf_weight_names:
                 if "gate_proj" in base_name:
-                    per_base[base_name] = linear_out_tensor[:split_size, :]
+                    per_base[base_name] = gate_weight
                 elif "up_proj" in base_name:
-                    per_base[base_name] = linear_out_tensor[split_size:, :]
+                    per_base[base_name] = up_weight
                 else:
                     raise ValueError(f"Unknown fused-fc1 base weight name: {base_name}")
             return per_base
@@ -604,7 +711,7 @@ def _merge_lora_adapter_weights(
         if len(adapter_weights) > 1 and all(
             w.adapter_key in ADAPTER_NAME_MAP.values() for w in adapter_weights if w.adapter_key
         ):
-            return self._merge_canonical_adapter_from_weights(converted_weights_dict, adapter_weights)
+            return self._merge_canonical_adapter_from_weights(megatron_model, converted_weights_dict, adapter_weights)

         assert len(adapter_weights) == 1, "Expected a single adapter weight for standard LoRA merging"

@@ -612,23 +719,56 @@ def _merge_lora_adapter_weights(
         alpha, dim = adapter_weight.alpha, adapter_weight.dim
         linear_in_weight = adapter_weight.linear_in_weight.weight
         linear_out_weight = adapter_weight.linear_out_weight.weight
+        num_moe_experts = megatron_model[0].config.num_moe_experts
+        is_expert = is_expert_linear(adapter_weight.global_base_prefix)
+        is_grouped_expert = is_expert and ".local_experts." not in adapter_weight.global_base_prefix
+        expert_linear_in_gathered = None
+        expert_linear_out_gathered = None
+        if is_grouped_expert:
+            expert_linear_in_gathered = self._gather_expert_adapter_weight(linear_in_weight)
+            expert_linear_out_gathered = self._gather_expert_adapter_weight(linear_out_weight)

         base_weight_shape = next(iter(converted_weights_dict.values())).shape
         weight_names = converted_weights_dict.keys()
         is_fused_fc1 = self._is_fused_fc1_gate_up(weight_names, linear_out_weight, base_weight_shape)
-        is_fused_qkv = self._is_fused_qkv(weight_names)
+        is_fused_qkv = self._is_fused_qkv(weight_names) and not is_expert
         qkv_linear_out_weights = (
             self._split_qkv_linear_out_weight(megatron_model, linear_out_weight) if is_fused_qkv else None
         )
+        fc1_gate_weight = fc1_up_weight = None
+        if is_fused_fc1 and not is_expert:
+            fc1_gate_weight, fc1_up_weight = self._split_fused_fc1_linear_out_weight(
+                linear_out_weight,
+                is_expert=is_expert,
+            )

         for hf_name, base_weight in list(converted_weights_dict.items()):
+            current_linear_in_weight = linear_in_weight
             current_linear_out_weight = linear_out_weight
+            if is_grouped_expert:
+                expert_idx = self._infer_hf_expert_idx(hf_name)
+                current_linear_in_weight = self._select_expert_adapter_weight(
+                    linear_in_weight,
+                    expert_linear_in_gathered,
+                    expert_idx,
+                    num_moe_experts,
+                )
+                current_linear_out_weight = self._select_expert_adapter_weight(
+                    linear_out_weight,
+                    expert_linear_out_gathered,
+                    expert_idx,
+                    num_moe_experts,
+                )
             if is_fused_fc1:
-                split_size = linear_out_weight.shape[0] // 2
+                if is_expert:
+                    fc1_gate_weight, fc1_up_weight = self._split_fused_fc1_linear_out_weight(
+                        current_linear_out_weight,
+                        is_expert=is_expert,
+                    )
                 if "gate_proj" in hf_name:
-                    current_linear_out_weight = linear_out_weight[:split_size, :]
+                    current_linear_out_weight = fc1_gate_weight
                 elif "up_proj" in hf_name:
-                    current_linear_out_weight = linear_out_weight[split_size:, :]
+                    current_linear_out_weight = fc1_up_weight
                 else:
                     raise ValueError(f"Unknown weight name: {hf_name}")
             elif is_fused_qkv and qkv_linear_out_weights is not None:
@@ -638,7 +778,7 @@ def _merge_lora_adapter_weights(
                 current_linear_out_weight = qkv_linear_out_weights[projection_key]

             merged_weight = self._merge_single_adapter_weight(
-                base_weight, alpha, dim, linear_in_weight, current_linear_out_weight
+                base_weight, alpha, dim, current_linear_in_weight, current_linear_out_weight
             )
             converted_weights_dict[hf_name] = merged_weight

@@ -672,29 +812,63 @@ def _merge_single_adapter_weight(

     def _merge_canonical_adapter_from_weights(
         self,
+        megatron_model: List[MegatronModel],
         converted_weights_dict: Dict[str, torch.Tensor],
         adapter_weights: List[AdapterWeight],
     ) -> Dict[str, torch.Tensor]:
         """Merge CanonicalLoRA adapters using pre-materialized adapter weights."""
         adapter_lookup = {aw.adapter_key: aw for aw in adapter_weights}
+        expert_linear_in_gathered: Dict[str, List[torch.Tensor]] = {}
+        expert_linear_out_gathered: Dict[str, List[torch.Tensor]] = {}
+        base_prefix = adapter_weights[0].global_base_prefix
+        num_moe_experts = megatron_model[0].config.num_moe_experts
+        is_expert = is_expert_linear(base_prefix)
+        is_grouped_expert = is_expert and ".local_experts." not in base_prefix
+        if is_grouped_expert:
+            for adapter_key, adapter_weight in adapter_lookup.items():
+                expert_linear_in_gathered[adapter_key] = self._gather_expert_adapter_weight(
+                    adapter_weight.linear_in_weight.weight,
+                )
+                expert_linear_out_gathered[adapter_key] = self._gather_expert_adapter_weight(
+                    adapter_weight.linear_out_weight.weight,
+                )

         for hf_name, base_weight in converted_weights_dict.items():
             target_adapter = None
+            target_adapter_key = None
             for suffix, adapter_key in ADAPTER_NAME_MAP.items():
                 if hf_name.endswith(suffix):
                     target_adapter = adapter_lookup.get(adapter_key)
+                    target_adapter_key = adapter_key
                     break

             if target_adapter is None:
                 raise ValueError(f"Adapter name mapping not found for {hf_name}")

+            linear_in_weight = target_adapter.linear_in_weight.weight
+            linear_out_weight = target_adapter.linear_out_weight.weight
+            if is_grouped_expert:
+                expert_idx = self._infer_hf_expert_idx(hf_name)
+                linear_in_weight = self._select_expert_adapter_weight(
+                    linear_in_weight,
+                    expert_linear_in_gathered.get(target_adapter_key),
+                    expert_idx,
+                    num_moe_experts,
+                )
+                linear_out_weight = self._select_expert_adapter_weight(
+                    linear_out_weight,
+                    expert_linear_out_gathered.get(target_adapter_key),
+                    expert_idx,
+                    num_moe_experts,
+                )
+
             merged_weight = self._merge_single_adapter_weight(
                 base_weight,
                 target_adapter.alpha,
                 target_adapter.dim,
-                target_adapter.linear_in_weight.weight,
-                target_adapter.linear_out_weight.weight,
+                linear_in_weight,
+                linear_out_weight,
             )
             converted_weights_dict[hf_name] = merged_weight
diff --git a/tests/unit_tests/models/test_auto_bridge.py b/tests/unit_tests/models/test_auto_bridge.py
index 01d1c4b92e..32598748ed 100644
--- a/tests/unit_tests/models/test_auto_bridge.py
+++ b/tests/unit_tests/models/test_auto_bridge.py
@@ -545,7 +545,9 @@ def test_save_hf_pretrained(self, mock_is_init, mock_is_avail, mock_barrier, moc
         # Check artifacts were saved on rank 0
         mock_hf_model.save_artifacts.assert_called_once_with("./output_dir", original_source_path=None)

-        mock_save_hf_weights.assert_called_once_with(mock_megatron_model, "./output_dir", True, True)
+        mock_save_hf_weights.assert_called_once_with(
+            mock_megatron_model, "./output_dir", True, True, merge_adapter_weights=True
+        )

@patch("torch.distributed.get_rank", return_value=1) @patch("torch.distributed.is_initialized", return_value=True) @@ -566,7 +568,9 @@ def test_save_hf_pretrained_non_zero_rank( # Artifacts should NOT be saved on non-zero rank mock_hf_model.save_artifacts.assert_not_called() - mock_save_hf_weights.assert_called_once_with(mock_megatron_model, "./output_dir", True, True) + mock_save_hf_weights.assert_called_once_with( + mock_megatron_model, "./output_dir", True, True, merge_adapter_weights=True + ) def test_export_hf_weights(self): """Test exporting weights from Megatron to HF format.""" diff --git a/tests/unit_tests/models/test_model_bridge_lora.py b/tests/unit_tests/models/test_model_bridge_lora.py index e38f1c7b69..02e1a98240 100644 --- a/tests/unit_tests/models/test_model_bridge_lora.py +++ b/tests/unit_tests/models/test_model_bridge_lora.py @@ -49,7 +49,11 @@ def test_merge_lora_adapter_weights_merges(monkeypatch): linear_out_weight=MegatronWeightTuple("out", torch.eye(4), vp_stage=0), ) - updated = bridge._merge_lora_adapter_weights([Mock(config=SimpleNamespace())], converted, [adapter_weight]) + updated = bridge._merge_lora_adapter_weights( + [Mock(config=SimpleNamespace(num_moe_experts=0))], + converted, + [adapter_weight], + ) expected = base_weight + torch.eye(4) torch.testing.assert_close(updated["hf.weight"], expected) @@ -85,11 +89,58 @@ def test_merge_lora_adapter_weights_fused_fc1(monkeypatch): linear_out_weight=MegatronWeightTuple("out", linear_out, vp_stage=0), ) - updated = bridge._merge_lora_adapter_weights([Mock(config=SimpleNamespace())], converted, [adapter_weight]) + monkeypatch.setattr( + "megatron.bridge.models.conversion.peft_bridge.parallel_state.get_tensor_model_parallel_world_size", + lambda: 1, + ) + + updated = bridge._merge_lora_adapter_weights( + [Mock(config=SimpleNamespace(num_moe_experts=0))], + converted, + [adapter_weight], + ) torch.testing.assert_close(updated["decoder.layers.0.mlp.gate_proj.weight"], torch.eye(4)) torch.testing.assert_close(updated["decoder.layers.0.mlp.up_proj.weight"], 2 * torch.eye(4)) +def test_merge_lora_adapter_weights_fused_fc1_tp_aware(monkeypatch): + bridge = DummyBridge() + base = torch.zeros(4, 4) + converted = { + "decoder.layers.0.mlp.gate_proj.weight": base.clone(), + "decoder.layers.0.mlp.up_proj.weight": base.clone(), + } + + gate0 = torch.arange(0, 8, dtype=base.dtype).reshape(2, 4) + up0 = torch.arange(100, 108, dtype=base.dtype).reshape(2, 4) + gate1 = torch.arange(200, 208, dtype=base.dtype).reshape(2, 4) + up1 = torch.arange(300, 308, dtype=base.dtype).reshape(2, 4) + linear_out = torch.cat([gate0, up0, gate1, up1], dim=0) + adapter_weight = AdapterWeight( + global_base_prefix="decoder.layers.0.mlp.linear_fc1", + adapter_key=None, + alpha=1, + dim=1, + linear_in_weight=MegatronWeightTuple("in", torch.eye(4), vp_stage=0), + linear_out_weight=MegatronWeightTuple("out", linear_out, vp_stage=0), + ) + + monkeypatch.setattr( + "megatron.bridge.models.conversion.peft_bridge.parallel_state.get_tensor_model_parallel_world_size", + lambda: 2, + ) + + updated = bridge._merge_lora_adapter_weights( + [Mock(config=SimpleNamespace(num_moe_experts=0))], + converted, + [adapter_weight], + ) + expected_gate = torch.cat([gate0, gate1], dim=0) + expected_up = torch.cat([up0, up1], dim=0) + torch.testing.assert_close(updated["decoder.layers.0.mlp.gate_proj.weight"], expected_gate) + torch.testing.assert_close(updated["decoder.layers.0.mlp.up_proj.weight"], expected_up) + + def test_merge_lora_adapter_weights_qkv_split(monkeypatch): 
     bridge = DummyBridge()
     config = SimpleNamespace(
@@ -98,6 +149,7 @@ def test_merge_lora_adapter_weights_qkv_split(monkeypatch):
         kv_channels=None,
         hidden_size=4,
         attention_output_gate=False,
+        num_moe_experts=0,
     )
     megatron_model = [SimpleNamespace(config=config)]
     converted = {
@@ -159,7 +211,12 @@ def test_merge_canonical_adapter_from_weights(monkeypatch):
         linear_out_weight=MegatronWeightTuple("out_v", 3 * torch.ones(1, 2), vp_stage=0),
     )
-    updated = bridge._merge_canonical_adapter_from_weights(converted, [adapter_q, adapter_k, adapter_v])
+    megatron_model = [SimpleNamespace(config=SimpleNamespace(num_moe_experts=0))]
+    updated = bridge._merge_canonical_adapter_from_weights(
+        megatron_model,
+        converted,
+        [adapter_q, adapter_k, adapter_v],
+    )

     torch.testing.assert_close(updated["decoder.layers.0.self_attn.q_proj.weight"], torch.ones(2, 2))
     torch.testing.assert_close(updated["decoder.layers.0.self_attn.k_proj.weight"], 2 * torch.ones(1, 2))
     torch.testing.assert_close(updated["decoder.layers.0.self_attn.v_proj.weight"], 3 * torch.ones(1, 2))
@@ -172,6 +229,9 @@ class DummyGroup:
         def size(self):
             return 1

+        def rank(self):
+            return 0
+
     fake_param = torch.nn.Parameter(torch.zeros(1, 1))

     class FakeModel:
@@ -216,6 +276,9 @@ class DummyGroup:
         def size(self):
             return 1

+        def rank(self):
+            return 0
+
     class FakeAdapter:
         def __init__(self):
             self.linear_in = SimpleNamespace(weight=torch.ones(2, 2))
@@ -432,7 +495,8 @@ def test_stream_adapter_weights_megatron_to_hf(monkeypatch):
         lambda *_: [adapter_weight],
     )
-    weights = list(bridge.stream_adapter_weights_megatron_to_hf([Mock()], cpu=False, show_progress=False))
+    megatron_model = [SimpleNamespace(config=SimpleNamespace(num_moe_experts=0))]
+    weights = list(bridge.stream_adapter_weights_megatron_to_hf(megatron_model, cpu=False, show_progress=False))

     assert len(weights) == 2
     assert weights[0].param_name.endswith(".linear_in.weight")
     assert weights[1].param_name.endswith(".linear_out.weight")
@@ -494,7 +558,7 @@ def test_stream_adapter_weights_megatron_to_hf_qkv(monkeypatch):
     weights = list(
         bridge.stream_adapter_weights_megatron_to_hf(
-            [SimpleNamespace(config=SimpleNamespace())],
+            [SimpleNamespace(config=SimpleNamespace(num_moe_experts=0))],
             cpu=False,
             show_progress=False,
         )
     )
@@ -560,10 +624,14 @@ def test_stream_adapter_weights_megatron_to_hf_fused_fc1(monkeypatch):
             "model.layers.0.mlp.up_proj.weight",
         ],
     )
+    monkeypatch.setattr(
+        "megatron.bridge.models.conversion.peft_bridge.parallel_state.get_tensor_model_parallel_world_size",
+        lambda: 1,
+    )

     weights = list(
         bridge.stream_adapter_weights_megatron_to_hf(
-            [SimpleNamespace(config=SimpleNamespace())],
+            [SimpleNamespace(config=SimpleNamespace(num_moe_experts=0))],
             cpu=False,
             show_progress=False,
         )
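---

Reviewer notes (not part of the patch): two standalone sketches of the index arithmetic this diff introduces. Both re-implement the logic with plain arguments in place of the `parallel_state` queries, so the helper names `split_fused_fc1` and `select_expert_weight` and their signatures are illustrative only.

First, the TP-aware fused-FC1 split. Each tensor-parallel rank holds a `[gate_i; up_i]` block of the fused `linear_fc1` LoRA `linear_out`, so after concatenation the rows are interleaved as `[gate_0, up_0, gate_1, up_1, ...]`; a naive halfway split (the pre-patch behavior) would mix gate and up rows whenever `tp_size > 1`. A minimal sketch of what `_split_fused_fc1_linear_out_weight` does, assuming `tp_size` is passed explicitly:

```python
import torch


def split_fused_fc1(linear_out: torch.Tensor, tp_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Recover gate/up from rows laid out as [gate_0, up_0, gate_1, up_1, ...]."""
    if tp_size <= 1:
        # Unsharded: the first half is gate, the second half is up.
        return torch.chunk(linear_out, 2, dim=0)
    shard_size = linear_out.shape[0] // tp_size
    gate_parts, up_parts = [], []
    for shard in torch.split(linear_out, shard_size, dim=0):
        # Within each TP shard, the gate rows precede the up rows.
        gate_shard, up_shard = torch.chunk(shard, 2, dim=0)
        gate_parts.append(gate_shard)
        up_parts.append(up_shard)
    return torch.cat(gate_parts, dim=0), torch.cat(up_parts, dim=0)


# Toy check mirroring test_merge_lora_adapter_weights_fused_fc1_tp_aware:
# with tp_size=2, a halfway split would return [gate0, up0] as "gate".
gate0, up0 = torch.zeros(2, 4), torch.ones(2, 4)
gate1, up1 = 2 * torch.ones(2, 4), 3 * torch.ones(2, 4)
fused = torch.cat([gate0, up0, gate1, up1], dim=0)
gate, up = split_fused_fc1(fused, tp_size=2)
assert torch.equal(gate, torch.cat([gate0, gate1], dim=0))
assert torch.equal(up, torch.cat([up0, up1], dim=0))
```

Second, the expert selection. `_select_expert_adapter_weight` assumes experts are distributed contiguously across expert-parallel ranks, so the rank owning a given HF expert index is a single integer division over the all-gathered slices:

```python
# Hypothetical helper mirroring _select_expert_adapter_weight's rank arithmetic;
# gathered stands in for the all_gather result across EP ranks.
def select_expert_weight(gathered, expert_idx, num_experts, ep_size):
    num_experts_per_rank = num_experts // ep_size
    rank = expert_idx // num_experts_per_rank  # e.g. expert 5 of 8 with ep_size=4 lives on rank 2
    return gathered[rank]
```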