Merged
220 changes: 197 additions & 23 deletions src/megatron/bridge/models/conversion/peft_bridge.py
@@ -15,6 +15,7 @@
from __future__ import annotations

import itertools
import re
from collections import defaultdict
from dataclasses import dataclass
from string import digits
@@ -259,6 +260,14 @@ def _infer_qkv_projection_from_name(self, hf_name: str) -> Optional[str]:
return "v_proj"
return None

def _infer_hf_expert_idx(self, hf_name: str) -> Optional[int]:
"""Return the expert index embedded in an HF MoE weight name."""

match = re.search(r"\bexperts\.(\d+)\b", hf_name)
if match is None:
return None
return int(match.group(1))

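For illustration (a standalone sketch, not part of the diff): the regex maps HF MoE weight names to expert indices and returns None for names without one. The weight names below are hypothetical HF-style examples.

import re

def infer_hf_expert_idx(hf_name):
    match = re.search(r"\bexperts\.(\d+)\b", hf_name)
    return int(match.group(1)) if match else None

assert infer_hf_expert_idx("model.layers.0.mlp.experts.7.up_proj.weight") == 7
assert infer_hf_expert_idx("model.layers.0.self_attn.q_proj.weight") is None
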
def _split_qkv_linear_out_weight(
self,
megatron_model: Union[MegatronModel, List[MegatronModel]],
@@ -270,6 +279,70 @@
q_out, k_out, v_out = split_qkv_weights(model.config, linear_out_weight)
return {"q_proj": q_out, "k_proj": k_out, "v_proj": v_out}

def _split_fused_fc1_linear_out_weight(
self,
linear_out_weight: torch.Tensor,
*,
is_expert: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Split fused FC1 LoRA linear_out into gate/up with TP-aware ordering."""

tp_size = (
parallel_state.get_expert_tensor_parallel_world_size()
if is_expert
else parallel_state.get_tensor_model_parallel_world_size()
)
if tp_size <= 1:
return torch.chunk(linear_out_weight, 2, dim=0)

shard_size = linear_out_weight.shape[0] // tp_size
if shard_size * tp_size != linear_out_weight.shape[0] or shard_size % 2 != 0:
return torch.chunk(linear_out_weight, 2, dim=0)

shards = torch.split(linear_out_weight, shard_size, dim=0)
gate_parts = []
up_parts = []
for shard in shards:
gate_shard, up_shard = torch.chunk(shard, 2, dim=0)
gate_parts.append(gate_shard)
up_parts.append(up_shard)
gate = torch.cat(gate_parts, dim=0)
up = torch.cat(up_parts, dim=0)
return gate, up

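To see why the TP-aware path is needed: with tensor parallelism, each rank's fused FC1 shard is laid out as [gate_shard; up_shard], so a single torch.chunk over the full dim 0 would mix gate and up rows. A minimal sketch with hypothetical sizes (tp_size=2, four gate rows and four up rows):

import torch

tp_size = 2
gate = torch.arange(4).float().unsqueeze(1)   # global gate rows 0..3
up = torch.arange(4, 8).float().unsqueeze(1)  # global up rows 4..7

# Fused FC1 layout: each TP rank holds [its gate shard; its up shard].
fused = torch.cat([gate[:2], up[:2], gate[2:], up[2:]], dim=0)

# TP-aware split: de-interleave per shard, then re-concatenate.
shards = torch.split(fused, fused.shape[0] // tp_size, dim=0)
gate_parts, up_parts = zip(*(torch.chunk(s, 2, dim=0) for s in shards))
assert torch.equal(torch.cat(gate_parts), gate)
assert torch.equal(torch.cat(up_parts), up)
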
def _gather_expert_adapter_weight(
self,
weight: torch.Tensor,
) -> Optional[List[torch.Tensor]]:
"""Gather expert-sharded adapter weights across EP ranks when needed."""
ep_size = parallel_state.get_expert_model_parallel_world_size()
if ep_size <= 1:
return None
assert weight.ndim < 3

gathered = [torch.empty_like(weight) for _ in range(ep_size)]
torch.distributed.all_gather(gathered, weight, group=parallel_state.get_expert_model_parallel_group())
return gathered

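The gather follows the stock torch.distributed.all_gather pattern: every EP rank contributes its local adapter shard and receives a list indexed by EP rank. A runnable single-process sketch of the call shape (gloo backend and sizes are hypothetical; the real code gathers over the expert-model-parallel group):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

weight = torch.randn(8, 4)  # this rank's adapter shard
gathered = [torch.empty_like(weight) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, weight)  # gathered[r] holds rank r's shard
dist.destroy_process_group()
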
def _select_expert_adapter_weight(
self,
weight: torch.Tensor,
gathered: List[torch.Tensor],
expert_idx: int,
num_experts: int,
) -> torch.Tensor:
"""Select the per-expert adapter weight slice if present."""

assert weight.ndim < 3

ep_size = parallel_state.get_expert_model_parallel_world_size()
if ep_size <= 1:
return weight

num_experts_per_rank = num_experts // ep_size
rank = expert_idx // num_experts_per_rank
return gathered[rank]

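The rank selection assumes experts are block-partitioned across EP ranks, so expert_idx // (num_experts // ep_size) names the owning rank. A sketch with hypothetical sizes:

num_experts, ep_size = 8, 4
num_experts_per_rank = num_experts // ep_size  # 2 experts per EP rank
owning_rank = [e // num_experts_per_rank for e in range(num_experts)]
assert owning_rank == [0, 0, 1, 1, 2, 2, 3, 3]
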
def _megatron_global_adapters_info_all_pp_ranks(
self, megatron_model: Union[MegatronModel, List[MegatronModel]]
) -> List[tuple[str, str, bool, bool, int, int, int, int]]:
@@ -379,7 +452,7 @@ def build_adapter_conversion_tasks(
local_linear_in_name, local_linear_out_name = global_linear_in_name, global_linear_out_name

base_suffix = ".weight"
if is_expert_linear(global_base_prefix):
if is_expert_linear(global_base_prefix) and ".local_experts." not in global_base_prefix:
# Grouped experts store per-expert weights as .weight0, .weight1, ...; use .weight0 to resolve the HF mapping
base_suffix = ".weight0"

@@ -514,15 +587,44 @@ def stream_adapter_weights_megatron_to_hf(

linear_in_tensor = adapter_weight.linear_in_weight.weight
linear_out_tensor = adapter_weight.linear_out_weight.weight
if cpu:
linear_in_tensor = linear_in_tensor.cpu()
linear_out_tensor = linear_out_tensor.cpu()
is_expert = is_expert_linear(adapter_task.global_base_prefix)
is_grouped_expert = is_expert and ".local_experts." not in adapter_task.global_base_prefix
expert_linear_in_gathered = None
expert_linear_out_gathered = None
if is_grouped_expert:
expert_linear_in_gathered = self._gather_expert_adapter_weight(
linear_in_tensor,
)
expert_linear_out_gathered = self._gather_expert_adapter_weight(
linear_out_tensor,
)

base_suffixes = [".weight"]
if is_expert_linear(adapter_task.global_base_prefix):
if is_grouped_expert:
base_suffixes = [f".weight{expert_num}" for expert_num in range(num_moe_experts)]

for base_suffix in base_suffixes:
current_linear_in_tensor = linear_in_tensor
current_linear_out_tensor = linear_out_tensor
if is_grouped_expert:
expert_idx = int(base_suffix[len(".weight") :])
current_linear_in_tensor = self._select_expert_adapter_weight(
linear_in_tensor,
expert_linear_in_gathered,
expert_idx,
num_moe_experts,
)
current_linear_out_tensor = self._select_expert_adapter_weight(
linear_out_tensor,
expert_linear_out_gathered,
expert_idx,
num_moe_experts,
)

if cpu:
current_linear_in_tensor = current_linear_in_tensor.cpu()
current_linear_out_tensor = current_linear_out_tensor.cpu()

base_hf_weight_names = self._get_base_hf_weight_names_for_adapter(
mapping_registry,
adapter_task.global_base_prefix,
@@ -541,25 +643,27 @@ def _merge_lora_adapter_weights(
per_base_linear_out = self._get_fused_adapter_linear_out_slices(
megatron_model,
base_hf_weight_names,
linear_out_tensor,
current_linear_out_tensor,
is_expert=is_expert_linear(adapter_task.global_base_prefix),
)
if per_base_linear_out is not None:
for index, base_name in enumerate(base_hf_weight_names):
current_linear_out = per_base_linear_out.get(base_name)
assert current_linear_out is not None, "unknown projection name"
current_linear_out_tensor = per_base_linear_out.get(base_name)
assert current_linear_out_tensor is not None, "unknown projection name"

yield HFWeightTuple(linear_in_hf_names[index], linear_in_tensor)
yield HFWeightTuple(linear_out_hf_names[index], current_linear_out)
yield HFWeightTuple(linear_in_hf_names[index], current_linear_in_tensor)
yield HFWeightTuple(linear_out_hf_names[index], current_linear_out_tensor)
continue

yield HFWeightTuple(linear_in_hf_names[0], linear_in_tensor)
yield HFWeightTuple(linear_out_hf_names[0], linear_out_tensor)
yield HFWeightTuple(linear_in_hf_names[0], current_linear_in_tensor)
yield HFWeightTuple(linear_out_hf_names[0], current_linear_out_tensor)

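The grouped-expert suffixes enumerate Megatron's fused per-expert weights as .weight0, .weight1, ..., and the expert index is recovered by stripping the .weight prefix. A sketch with a hypothetical expert count:

num_moe_experts = 4
base_suffixes = [f".weight{expert_num}" for expert_num in range(num_moe_experts)]
assert base_suffixes == [".weight0", ".weight1", ".weight2", ".weight3"]
assert int(".weight3"[len(".weight"):]) == 3
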
def _get_fused_adapter_linear_out_slices(
self,
megatron_model: List[MegatronModel],
base_hf_weight_names: List[str],
linear_out_tensor: torch.Tensor,
is_expert: bool = False,
) -> Optional[Dict[str, torch.Tensor]]:
"""Return per-base-name linear_out slices for fused adapters, else None.

@@ -580,13 +684,16 @@

is_fused_fc1 = self._is_fused_fc1_gate_up(base_hf_weight_names, linear_out_tensor)
if is_fused_fc1:
split_size = linear_out_tensor.shape[0] // 2
gate_weight, up_weight = self._split_fused_fc1_linear_out_weight(
linear_out_tensor,
is_expert=is_expert,
)
per_base = {}
for base_name in base_hf_weight_names:
if "gate_proj" in base_name:
per_base[base_name] = linear_out_tensor[:split_size, :]
per_base[base_name] = gate_weight
elif "up_proj" in base_name:
per_base[base_name] = linear_out_tensor[split_size:, :]
per_base[base_name] = up_weight
else:
raise ValueError(f"Unknown fused-fc1 base weight name: {base_name}")
return per_base
@@ -604,31 +711,64 @@ def _merge_lora_adapter_weights(
if len(adapter_weights) > 1 and all(
w.adapter_key in ADAPTER_NAME_MAP.values() for w in adapter_weights if w.adapter_key
):
return self._merge_canonical_adapter_from_weights(converted_weights_dict, adapter_weights)
return self._merge_canonical_adapter_from_weights(megatron_model, converted_weights_dict, adapter_weights)

assert len(adapter_weights) == 1, "Expected a single adapter weight for standard LoRA merging"

adapter_weight = adapter_weights[0]
alpha, dim = adapter_weight.alpha, adapter_weight.dim
linear_in_weight = adapter_weight.linear_in_weight.weight
linear_out_weight = adapter_weight.linear_out_weight.weight
num_moe_experts = megatron_model[0].config.num_moe_experts
is_expert = is_expert_linear(adapter_weight.global_base_prefix)
is_grouped_expert = is_expert and ".local_experts." not in adapter_weight.global_base_prefix
expert_linear_in_gathered = None
expert_linear_out_gathered = None
if is_grouped_expert:
expert_linear_in_gathered = self._gather_expert_adapter_weight(linear_in_weight)
expert_linear_out_gathered = self._gather_expert_adapter_weight(linear_out_weight)

base_weight_shape = next(iter(converted_weights_dict.values())).shape
weight_names = converted_weights_dict.keys()
is_fused_fc1 = self._is_fused_fc1_gate_up(weight_names, linear_out_weight, base_weight_shape)
is_fused_qkv = self._is_fused_qkv(weight_names)
is_fused_qkv = self._is_fused_qkv(weight_names) and not is_expert
qkv_linear_out_weights = (
self._split_qkv_linear_out_weight(megatron_model, linear_out_weight) if is_fused_qkv else None
)
fc1_gate_weight = fc1_up_weight = None
if is_fused_fc1 and not is_expert:
fc1_gate_weight, fc1_up_weight = self._split_fused_fc1_linear_out_weight(
linear_out_weight,
is_expert=is_expert,
)

for hf_name, base_weight in list(converted_weights_dict.items()):
current_linear_in_weight = linear_in_weight
current_linear_out_weight = linear_out_weight
if is_grouped_expert:
expert_idx = self._infer_hf_expert_idx(hf_name)
current_linear_in_weight = self._select_expert_adapter_weight(
linear_in_weight,
expert_linear_in_gathered,
expert_idx,
num_moe_experts,
)
current_linear_out_weight = self._select_expert_adapter_weight(
linear_out_weight,
expert_linear_out_gathered,
expert_idx,
num_moe_experts,
)
if is_fused_fc1:
split_size = linear_out_weight.shape[0] // 2
if is_expert:
fc1_gate_weight, fc1_up_weight = self._split_fused_fc1_linear_out_weight(
current_linear_out_weight,
is_expert=is_expert,
)
if "gate_proj" in hf_name:
current_linear_out_weight = linear_out_weight[:split_size, :]
current_linear_out_weight = fc1_gate_weight
elif "up_proj" in hf_name:
current_linear_out_weight = linear_out_weight[split_size:, :]
current_linear_out_weight = fc1_up_weight
else:
raise ValueError(f"Unknown weight name: {hf_name}")
elif is_fused_qkv and qkv_linear_out_weights is not None:
Expand All @@ -638,7 +778,7 @@ def _merge_lora_adapter_weights(
current_linear_out_weight = qkv_linear_out_weights[projection_key]

merged_weight = self._merge_single_adapter_weight(
base_weight, alpha, dim, linear_in_weight, current_linear_out_weight
base_weight, alpha, dim, current_linear_in_weight, current_linear_out_weight
)
converted_weights_dict[hf_name] = merged_weight

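For reference, _merge_single_adapter_weight (its body is outside this diff) presumably applies the standard LoRA fold, W' = W + (alpha / dim) * (B @ A), with linear_in as the A matrix and linear_out as B. A minimal sketch under that assumption:

import torch

def merge_lora(base, alpha, dim, linear_in, linear_out):
    # linear_in: (dim, in_features) ~ LoRA A; linear_out: (out_features, dim) ~ LoRA B
    return base + (alpha / dim) * (linear_out @ linear_in)

base = torch.zeros(16, 32)
lora_a = torch.randn(4, 32)
lora_b = torch.randn(16, 4)
merged = merge_lora(base, alpha=8.0, dim=4, linear_in=lora_a, linear_out=lora_b)
assert merged.shape == base.shape
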
@@ -672,29 +812,63 @@ def _merge_single_adapter_weight(

def _merge_canonical_adapter_from_weights(
self,
megatron_model: List[MegatronModel],
converted_weights_dict: Dict[str, torch.Tensor],
adapter_weights: List[AdapterWeight],
) -> Dict[str, torch.Tensor]:
"""Merge CanonicalLoRA adapters using pre-materialized adapter weights."""

adapter_lookup = {aw.adapter_key: aw for aw in adapter_weights}
expert_linear_in_gathered: Dict[str, List[torch.Tensor]] = {}
expert_linear_out_gathered: Dict[str, List[torch.Tensor]] = {}
base_prefix = adapter_weights[0].global_base_prefix
num_moe_experts = megatron_model[0].config.num_moe_experts
is_expert = is_expert_linear(base_prefix)
is_grouped_expert = is_expert and ".local_experts." not in base_prefix
if is_grouped_expert:
for adapter_key, adapter_weight in adapter_lookup.items():
expert_linear_in_gathered[adapter_key] = self._gather_expert_adapter_weight(
adapter_weight.linear_in_weight.weight,
)
expert_linear_out_gathered[adapter_key] = self._gather_expert_adapter_weight(
adapter_weight.linear_out_weight.weight,
)

for hf_name, base_weight in converted_weights_dict.items():
target_adapter = None
target_adapter_key = None
for suffix, adapter_key in ADAPTER_NAME_MAP.items():
if hf_name.endswith(suffix):
target_adapter = adapter_lookup.get(adapter_key)
target_adapter_key = adapter_key
break

if target_adapter is None:
raise ValueError(f"Adapter name mapping not found for {hf_name}")

linear_in_weight = target_adapter.linear_in_weight.weight
linear_out_weight = target_adapter.linear_out_weight.weight
if is_grouped_expert:
expert_idx = self._infer_hf_expert_idx(hf_name)
linear_in_weight = self._select_expert_adapter_weight(
linear_in_weight,
expert_linear_in_gathered.get(target_adapter_key),
expert_idx,
num_moe_experts,
)
linear_out_weight = self._select_expert_adapter_weight(
linear_out_weight,
expert_linear_out_gathered.get(target_adapter_key),
expert_idx,
num_moe_experts,
)

merged_weight = self._merge_single_adapter_weight(
base_weight,
target_adapter.alpha,
target_adapter.dim,
target_adapter.linear_in_weight.weight,
target_adapter.linear_out_weight.weight,
linear_in_weight,
linear_out_weight,
)
converted_weights_dict[hf_name] = merged_weight

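The canonical merge dispatches each HF weight name to its adapter by suffix. A sketch of the lookup with a hypothetical ADAPTER_NAME_MAP (the real mapping is defined elsewhere in the bridge module):

ADAPTER_NAME_MAP = {"q_proj.weight": "adapter_q", "up_proj.weight": "adapter_up"}  # hypothetical

def find_adapter_key(hf_name):
    for suffix, adapter_key in ADAPTER_NAME_MAP.items():
        if hf_name.endswith(suffix):
            return adapter_key
    raise ValueError(f"Adapter name mapping not found for {hf_name}")

assert find_adapter_key("model.layers.0.self_attn.q_proj.weight") == "adapter_q"
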
8 changes: 6 additions & 2 deletions tests/unit_tests/models/test_auto_bridge.py
@@ -545,7 +545,9 @@ def test_save_hf_pretrained(self, mock_is_init, mock_is_avail, mock_barrier, moc

# Check artifacts were saved on rank 0
mock_hf_model.save_artifacts.assert_called_once_with("./output_dir", original_source_path=None)
mock_save_hf_weights.assert_called_once_with(mock_megatron_model, "./output_dir", True, True)
mock_save_hf_weights.assert_called_once_with(
mock_megatron_model, "./output_dir", True, True, merge_adapter_weights=True
)

@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.is_initialized", return_value=True)
@@ -566,7 +568,9 @@ def test_save_hf_pretrained_non_zero_rank(

# Artifacts should NOT be saved on non-zero rank
mock_hf_model.save_artifacts.assert_not_called()
mock_save_hf_weights.assert_called_once_with(mock_megatron_model, "./output_dir", True, True)
mock_save_hf_weights.assert_called_once_with(
mock_megatron_model, "./output_dir", True, True, merge_adapter_weights=True
)

def test_export_hf_weights(self):
"""Test exporting weights from Megatron to HF format."""