Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions examples/conversion/adapter/stream_adapter_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from __future__ import annotations

import argparse
import math
import os
from contextlib import contextmanager
from pathlib import Path
Expand Down Expand Up @@ -162,14 +163,19 @@ def configure_device(device_index: int = 0) -> torch.device:


def calculate_required_world_size(args: argparse.Namespace) -> int:
    """Compute the minimum world size compatible with the requested parallelism.

    Megatron requires WORLD_SIZE to be divisible by both the dense TP/PP domain
    and the expert ETP/EP/PP domain. Those domains reuse the same global ranks,
    so the minimum compatible world size is their least common multiple instead
    of the raw product of tp, pp, ep, and etp.

    Args:
        args: Parsed CLI namespace providing ``tensor_model_parallel_size``,
            ``pipeline_model_parallel_size``, ``expert_model_parallel_size``,
            and ``expert_tensor_parallel_size``.

    Returns:
        The smallest world size divisible by both parallelism domains.
    """
    # Dense (non-expert) weights are sharded over the TP x PP ranks.
    dense_model_parallel_size = args.tensor_model_parallel_size * args.pipeline_model_parallel_size
    # Expert (MoE) weights are sharded over the ETP x EP x PP ranks.
    expert_model_parallel_size = (
        args.expert_tensor_parallel_size * args.expert_model_parallel_size * args.pipeline_model_parallel_size
    )
    # Both domains share the same global ranks, so the minimum compatible
    # world size is their LCM rather than their product.
    return math.lcm(dense_model_parallel_size, expert_model_parallel_size)


@contextmanager
Expand All @@ -189,7 +195,7 @@ def distributed_context(
raise RuntimeError(
f"Requested world_size={required_world_size} from model-parallel settings "
f"(tp={tp}, pp={pp}, ep={ep}, etp={etp}), but initialized world_size={world_size}. "
"Launch with torchrun --nproc_per_node equal to the product."
f"Launch with torchrun --nproc_per_node={required_world_size}."
)
yield world_size
return
Expand All @@ -200,7 +206,7 @@ def distributed_context(
if required_world_size > 1 and "WORLD_SIZE" not in os.environ:
raise RuntimeError(
"Distributed world size is greater than 1 but WORLD_SIZE is not set. "
"Launch with torchrun --nproc_per_node equal to the requested world size."
f"Launch with torchrun --nproc_per_node={required_world_size}."
)

if "MASTER_ADDR" in os.environ and "MASTER_PORT" in os.environ:
Expand All @@ -223,7 +229,7 @@ def distributed_context(
raise RuntimeError(
f"Requested world_size={required_world_size} from model-parallel settings "
f"(tp={tp}, pp={pp}, ep={ep}, etp={etp}), but initialized world_size={world_size}. "
"Launch with torchrun --nproc_per_node equal to the product."
f"Launch with torchrun --nproc_per_node={required_world_size}."
)
yield world_size
finally:
Expand Down Expand Up @@ -274,7 +280,7 @@ def stream_and_collect_adapters(
)

for weight_name, tensor in generator:
adapter_state[weight_name] = tensor
adapter_state[weight_name] = tensor.clone()
print_rank_0(f"Collected adapter tensor: {weight_name} with shape {tuple(tensor.shape)}")

if not adapter_state:
Expand All @@ -286,9 +292,7 @@ def stream_and_collect_adapters(
def _normalize_base_weight_name(param_name: str) -> str:
"""Remove the 'base_layer' suffix emitted when merge_adapter_weights=False."""

if param_name.endswith("base_layer.weight"):
return param_name[: -len("base_layer.weight")] + "weight"
return param_name
return param_name.replace(".base_layer.", ".")


def collect_hf_state_dict(
Expand Down Expand Up @@ -327,10 +331,14 @@ def merge_hf_lora_adapters(

for name, tensor in adapter_state.items():
if name.endswith(".lora_A.weight"):
base_name = name[: -len(".lora_A.weight")] + ".weight"
base_name = name[: -len(".lora_A.weight")]
if base_name not in base_state and f"{base_name}.weight" in base_state:
base_name = f"{base_name}.weight"
grouped.setdefault(base_name, {})["A"] = tensor
elif name.endswith(".lora_B.weight"):
base_name = name[: -len(".lora_B.weight")] + ".weight"
base_name = name[: -len(".lora_B.weight")]
if base_name not in base_state and f"{base_name}.weight" in base_state:
base_name = f"{base_name}.weight"
grouped.setdefault(base_name, {})["B"] = tensor

scale = alpha / float(dim)
Expand Down Expand Up @@ -395,7 +403,7 @@ def main() -> None:
f"🧮 Model-parallel settings: tp={args.tensor_model_parallel_size}, "
f"pp={args.pipeline_model_parallel_size}, "
f"ep={args.expert_model_parallel_size}, etp={args.expert_tensor_parallel_size}. "
f"Expected world_size={required_world_size}."
f"Minimum example world_size={required_world_size}."
)

print_rank_0(f"🔧 Loading Hugging Face model {args.hf_model_id} with bfloat16 weights...")
Expand Down
6 changes: 5 additions & 1 deletion src/megatron/bridge/models/conversion/model_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,7 +1097,11 @@ def stream_weights_megatron_to_hf(
final_tensor = tensor.cpu() if cpu else tensor

if not merge_adapter_weights and "to_wrap.weight" in task.global_param_name:
hf_name = hf_name[: -len("weight")] + "base_layer.weight"
suffix_pos = hf_name.rfind(".")
if suffix_pos == -1:
hf_name = hf_name + ".base_layer"
else:
hf_name = hf_name[:suffix_pos] + ".base_layer" + hf_name[suffix_pos:]

# Handle tied embeddings case
# TODO(yuya): fix this hard coded naming
Expand Down
55 changes: 35 additions & 20 deletions src/megatron/bridge/models/conversion/param_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2676,12 +2676,26 @@ def merge_gdn_linear_weights(
return in_proj


def split_gdn_linear_weights(provider: TransformerConfig, in_proj: torch.Tensor, tp_size: int = 1) -> torch.Tensor:
"""Split GDN linear weights into QKVZ and BA."""
def split_gdn_linear_weights(
provider: TransformerConfig,
in_proj: torch.Tensor,
tp_size: int = 1,
feature_dim: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Split GDN linear weights into QKVZ and BA.

Args:
provider: Transformer config with GDN dimensions.
in_proj: Packed in-proj tensor.
tp_size: Tensor-parallel world size used for packing layout.
feature_dim: Trailing tensor dimension used for reshape/split.
Defaults to ``provider.hidden_size`` for base weights, but LoRA
paths can pass the adapter rank here.
"""

assert tp_size >= 1, f"tp_size must be greater than 0, but got {tp_size=}"

hidden_size = provider.hidden_size
feature_dim = provider.hidden_size if feature_dim is None else feature_dim
qk_head_dim = provider.linear_key_head_dim
v_head_dim = provider.linear_value_head_dim
num_qk_heads = provider.linear_num_key_heads
Expand All @@ -2690,7 +2704,7 @@ def split_gdn_linear_weights(provider: TransformerConfig, in_proj: torch.Tensor,
qk_dim_local_tp = qk_head_dim * num_qk_heads_local_tp
v_dim_local_tp = v_head_dim * num_v_heads_local_tp

in_proj = in_proj.reshape(tp_size, -1, hidden_size)
in_proj = in_proj.reshape(tp_size, -1, feature_dim)
q, k, v, z, b, a = torch.split(
in_proj,
[
Expand All @@ -2704,12 +2718,12 @@ def split_gdn_linear_weights(provider: TransformerConfig, in_proj: torch.Tensor,
dim=1,
)

q, k, v, z, b, a = [weight.reshape(num_qk_heads, -1, hidden_size) for weight in [q, k, v, z, b, a]]
q, k, v, z, b, a = [weight.reshape(num_qk_heads, -1, feature_dim) for weight in [q, k, v, z, b, a]]
qkvz = torch.cat([q, k, v, z], dim=1)
ba = torch.cat([b, a], dim=1)

qkvz = qkvz.reshape(-1, hidden_size)
ba = ba.reshape(-1, hidden_size)
qkvz = qkvz.reshape(-1, feature_dim)
ba = ba.reshape(-1, feature_dim)

assert qkvz.numel() + ba.numel() == in_proj.numel(), (
f"QKVZBA weights are not correctly split, {qkvz.numel()=}, {ba.numel()=}, {in_proj.numel()=}"
Expand Down Expand Up @@ -2782,14 +2796,15 @@ def _split_gdn_grouped_to_separate(
config: TransformerConfig,
qkvz: torch.Tensor,
ba: torch.Tensor,
feature_dim: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Convert head-grouped ``qkvz`` and ``ba`` tensors (as produced by
:func:`split_gdn_linear_weights`) back into four flat tensors.

Returns:
Tuple of (qkv, z, b, a) where each tensor has a flat per-component layout.
"""
hidden_size = config.hidden_size
feature_dim = config.hidden_size if feature_dim is None else feature_dim
qk_head_dim = config.linear_key_head_dim
v_head_dim = config.linear_value_head_dim
num_qk_heads = config.linear_num_key_heads
Expand All @@ -2798,31 +2813,31 @@ def _split_gdn_grouped_to_separate(

expected_qkvz_dim0 = num_qk_heads * (qk_head_dim * 2 + v_per_group * v_head_dim * 2)
expected_ba_dim0 = num_qk_heads * v_per_group * 2
if qkvz.ndim != 2 or qkvz.shape[0] != expected_qkvz_dim0 or qkvz.shape[1] != hidden_size:
if qkvz.ndim != 2 or qkvz.shape[0] != expected_qkvz_dim0 or qkvz.shape[1] != feature_dim:
raise ValueError(
f"qkvz shape mismatch: expected ({expected_qkvz_dim0}, {hidden_size}), got {tuple(qkvz.shape)}"
f"qkvz shape mismatch: expected ({expected_qkvz_dim0}, {feature_dim}), got {tuple(qkvz.shape)}"
)
if ba.ndim != 2 or ba.shape[0] != expected_ba_dim0 or ba.shape[1] != hidden_size:
raise ValueError(f"ba shape mismatch: expected ({expected_ba_dim0}, {hidden_size}), got {tuple(ba.shape)}")
if ba.ndim != 2 or ba.shape[0] != expected_ba_dim0 or ba.shape[1] != feature_dim:
raise ValueError(f"ba shape mismatch: expected ({expected_ba_dim0}, {feature_dim}), got {tuple(ba.shape)}")

# --- Split grouped QKVZ ---
qkvz_g = qkvz.reshape(num_qk_heads, -1, hidden_size)
qkvz_g = qkvz.reshape(num_qk_heads, -1, feature_dim)
q_g, k_g, v_g, z_g = torch.split(
qkvz_g,
[qk_head_dim, qk_head_dim, v_per_group * v_head_dim, v_per_group * v_head_dim],
dim=1,
)
q_flat = q_g.reshape(-1, hidden_size)
k_flat = k_g.reshape(-1, hidden_size)
v_flat = v_g.reshape(-1, hidden_size)
z_flat = z_g.reshape(-1, hidden_size)
q_flat = q_g.reshape(-1, feature_dim)
k_flat = k_g.reshape(-1, feature_dim)
v_flat = v_g.reshape(-1, feature_dim)
z_flat = z_g.reshape(-1, feature_dim)
qkv = torch.cat([q_flat, k_flat, v_flat], dim=0)

# --- Split grouped BA ---
ba_g = ba.reshape(num_qk_heads, -1, hidden_size)
ba_g = ba.reshape(num_qk_heads, -1, feature_dim)
b_g, a_g = torch.split(ba_g, [v_per_group, v_per_group], dim=1)
b_flat = b_g.reshape(-1, hidden_size)
a_flat = a_g.reshape(-1, hidden_size)
b_flat = b_g.reshape(-1, feature_dim)
a_flat = a_g.reshape(-1, feature_dim)

return qkv, z_flat, b_flat, a_flat

Expand Down
Loading
Loading