Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 213 additions & 64 deletions modelopt/torch/export/unified_export_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

"""Code that export quantized Megatron Core models for deployment."""

import io
import json
import os
import tempfile
Expand Down Expand Up @@ -85,6 +86,10 @@
HybridModel = MambaModel
from megatron.core.models.multimodal.llava_model import LLaVAModel
from megatron.core.parallel_state import (
get_data_parallel_rank,
get_expert_model_parallel_group,
get_expert_model_parallel_rank,
get_expert_model_parallel_world_size,
get_pipeline_model_parallel_rank,
get_pipeline_model_parallel_world_size,
get_tensor_model_parallel_rank,
Expand Down Expand Up @@ -300,50 +305,61 @@ def save_pretrained(
# We use the last PP rank and the 1st EP rank to write the config because
# medusa_heads and eagle_module only exist in the last stage.
if is_last_stage_main_rank:
# Baseline: for a local source, copy every non-safetensors file
# (tokenizer, remote_code *.py, README, etc.); for a Hub-ID source,
# snapshot_download just the *.py sidecars (tokenizer comes via the
# AutoTokenizer fallback below). modelopt-owned files (config.json,
# generation_config.json, hf_quant_config.json, preprocessor_config.json)
# are overwritten below.
if self._hf_pretrained_model_name is not None:
if os.path.isdir(self._hf_pretrained_model_name):
copy_non_safetensor_files_from_ckpt(
self._hf_pretrained_model_name, save_directory
)
else:
copy_hf_ckpt_remote_code(self._hf_pretrained_model_name, save_directory)
self._hf_config.save_pretrained(save_directory)
try:
generation_config = transformers.GenerationConfig.from_pretrained(
self._hf_pretrained_model_name,
trust_remote_code=self.trust_remote_code,
)
generation_config.save_pretrained(save_directory)
except OSError:
pass
# Hub-ID / None source: fetch tokenizer files via AutoTokenizer.
if self._hf_pretrained_model_name is None or not os.path.isdir(
self._hf_pretrained_model_name
):
# Sidecar file writes (tokenizer, generation_config, hf_config,
# preprocessor, etc.) all target the same save_directory. At DP > 1
# or EP > 1 there are multiple `is_last_stage_main_rank` ranks; if
# they all race on these writes the contents stay correct (last
# writer wins with identical bytes) but partial-write windows can
# still corrupt a concurrent read by another rank. Pin to one rank.
if get_data_parallel_rank() == 0 and get_expert_model_parallel_rank() == 0:
# Baseline: for a local source, copy every non-safetensors file
# (tokenizer, remote_code *.py, README, etc.); for a Hub-ID source,
# snapshot_download just the *.py sidecars (tokenizer comes via the
# AutoTokenizer fallback below). modelopt-owned files (config.json,
# generation_config.json, hf_quant_config.json, preprocessor_config.json)
# are overwritten below.
if self._hf_pretrained_model_name is not None:
if os.path.isdir(self._hf_pretrained_model_name):
copy_non_safetensor_files_from_ckpt(
self._hf_pretrained_model_name, save_directory
)
else:
copy_hf_ckpt_remote_code(self._hf_pretrained_model_name, save_directory)
self._hf_config.save_pretrained(save_directory)
try:
tokenizer = transformers.AutoTokenizer.from_pretrained(
generation_config = transformers.GenerationConfig.from_pretrained(
self._hf_pretrained_model_name,
trust_remote_code=self.trust_remote_code,
)
tokenizer.save_pretrained(save_directory)
except (OSError, TypeError, ValueError, ImportError):
generation_config.save_pretrained(save_directory)
except OSError:
pass
# Hub-ID / None source: fetch tokenizer files via AutoTokenizer.
if self._hf_pretrained_model_name is None or not os.path.isdir(
self._hf_pretrained_model_name
):
try:
tokenizer = transformers.AutoTokenizer.from_pretrained(
self._hf_pretrained_model_name,
trust_remote_code=self.trust_remote_code,
)
tokenizer.save_pretrained(save_directory)
except (OSError, TypeError, ValueError, ImportError):
pass
try:
# Load and save preprocessor config from the original model
processor = AutoProcessor.from_pretrained(
self._hf_pretrained_model_name, trust_remote_code=self.trust_remote_code
)
if hasattr(processor, "image_processor"):
processor.image_processor.save_pretrained(save_directory)
except (OSError, ValueError, ImportError):
pass
try:
# Load and save preprocessor config from the original model
processor = AutoProcessor.from_pretrained(
self._hf_pretrained_model_name, trust_remote_code=self.trust_remote_code
)
if hasattr(processor, "image_processor"):
processor.image_processor.save_pretrained(save_directory)
except (OSError, ValueError, ImportError):
pass

# MTP load is in-memory per-rank state (mutates this rank's
# layer_state_dicts that feeds the per-rank safetensors save
# later); it must run on every last-stage main rank, not just
# the single writer.
mtp_state_dict = self._get_mtp_state_dict()
if len(mtp_state_dict) > 0:
layer_state_dicts[self.model.config.num_layers].update(mtp_state_dict)
Expand All @@ -354,7 +370,17 @@ def save_pretrained(
# kv_cache_dtype is only set on attention-owning ranks; writer rank may not be one.
gathered_kv_cache_dtype = self._gather_kv_cache_dtype()

if is_last_stage_main_rank and quantization is not None:
# Pin hf_quant_config.json assembly + write to a single rank for the same
# reason as the sidecar block above (and for symmetry with the config.json
# patch below). Self._hf_quant_config remains empty on non-writer ranks,
# which makes the config.json gate at the bottom of this method naturally
# a no-op on them.
if (
is_last_stage_main_rank
and get_data_parallel_rank() == 0
and get_expert_model_parallel_rank() == 0
and quantization is not None
):
if combined_layer_config_dict:
quantization_config = process_layer_quant_config(combined_layer_config_dict)
quantization_config["exclude_modules"] = combined_exclude_modules
Expand Down Expand Up @@ -400,16 +426,31 @@ def save_pretrained(
# Barrier to ensure the export_dir has been created.
torch.distributed.barrier()

# Newer versions of VLLM expect config.json with hf_quant_config
# Newer versions of VLLM expect config.json with hf_quant_config.
# Pin the writer to a single rank: last PP stage, TP rank 0 (handled by
# `is_last_stage_main_rank`), DP rank 0, EP rank 0. Without the DP/EP gates,
# multiple ranks satisfy `is_last_stage_main_rank` simultaneously (one per
# DPxEP cell), read-modify-write the same file, and any pair can interleave
# such that another rank reads a truncated mid-write file and raises
# JSONDecodeError. Bracket with barriers so every other rank waits for the
# single writer.
torch.distributed.barrier()
config_json_file = save_directory + "/config.json"
if self._hf_quant_config and os.path.exists(config_json_file):
if (
is_last_stage_main_rank
and get_data_parallel_rank() == 0
and get_expert_model_parallel_rank() == 0
and self._hf_quant_config
and os.path.exists(config_json_file)
):
with open(config_json_file) as f:
config_dict = json.load(f)
config_dict["quantization_config"] = convert_hf_quant_config_format(
self._hf_quant_config
)
with open(config_json_file, "w") as f:
json.dump(config_dict, f, indent=4)
torch.distributed.barrier()
Comment thread
coderabbitai[bot] marked this conversation as resolved.

# save_safetensors(state_dict, save_directory)
save_safetensors_by_layer_index(
Expand Down Expand Up @@ -614,7 +655,7 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]:
)
single_safetensors_file = None
except EntryNotFoundError:
# Model uses a single unsharded safetensors file check it for MTP weights.
# Model uses a single unsharded safetensors file -- check it for MTP weights.
safetensors_index_file = None
try:
single_safetensors_file = Path(
Expand Down Expand Up @@ -1040,6 +1081,17 @@ def _grouped_mlp_slicing(self, module, prefix, parallel_config=None):
quantization state from the module, then iterates over experts and saves each expert's
weight (and scales if quantized) under the HF-style per-expert prefix.

Collective behavior (EP > 1):
When the expert-model-parallel world size is greater than 1, ``num_experts`` is the
number of *local* experts on this rank, and ``range(num_experts)`` numbers them
0..N-1, which collides with every other EP rank. This method translates local
expert ids to *global* ids (preferring ``module.local_expert_indices`` when
exposed, falling back to the standard Megatron contiguous layout
``[ep_rank*N, (ep_rank+1)*N - 1]``) and ``all_gather_object`` the resulting
per-expert state across the EP process group so every rank in the EP group ends
up with the union of experts for this layer. All ranks in the EP group MUST call
this function for the same layer at the same logical step or the gather will hang.

This is the reverse of _grouped_mlp_merging in the importer.
"""
num_experts = module.num_gemms
Expand All @@ -1062,37 +1114,134 @@ def _grouped_mlp_slicing(self, module, prefix, parallel_config=None):

state_dict = module.state_dict()

for expert_id in range(num_experts):
expert_prefix = prefix.format(expert_id) + "."
self._record_layer_quant_config(expert_prefix, qformat, block_size)
weight_key = f"weight{expert_id}"
ep_size = (
get_expert_model_parallel_world_size() if torch.distributed.is_initialized() else 1
)
ep_rank = get_expert_model_parallel_rank() if torch.distributed.is_initialized() else 0

# Prefer the model's authoritative local-to-global expert mapping if exposed.
# The standard Megatron contiguous layout (rank r owns experts [r*N, (r+1)*N - 1])
# is the fallback when the module doesn't expose local_expert_indices.
#
# Normalize to a plain Python list of ints regardless of whether Megatron exposes
# this as a list, tuple, torch.Tensor, or numpy array. A bare `or` against the raw
# attribute trips `bool(tensor)` for multi-element tensors, and an empty list (or
# 0-d tensor) would silently fall through to the contiguous-layout fallback.
indices = getattr(module, "local_expert_indices", None)
if indices is None and getattr(module, "experts", None) is not None:
indices = getattr(module.experts, "local_expert_indices", None)
if indices is None:
local_expert_indices = [ep_rank * num_experts + i for i in range(num_experts)]
elif isinstance(indices, torch.Tensor):
local_expert_indices = indices.detach().cpu().tolist()
else:
local_expert_indices = [int(i) for i in indices]
if len(local_expert_indices) != num_experts:
raise ValueError(
f"local_expert_indices length {len(local_expert_indices)} doesn't match "
f"module.num_gemms {num_experts}"
)

if weight_key not in state_dict:
raise ValueError(f"Missing expected TEGroupedMLP expert weight: {weight_key}")
# Collective-safe missing-key detection: every EP rank checks for its weight keys
# locally, then we all_reduce(MAX) over a 0/1 flag so a single rank's missing key
# surfaces on every rank. Tensor must be on the current CUDA device -- NCCL has no
# CPU backend on the EP process group.
local_missing = [
k for k in (f"weight{i}" for i in range(num_experts)) if k not in state_dict
]
if ep_size > 1:
missing_flag = torch.tensor(
[1 if local_missing else 0],
dtype=torch.int32,
device=torch.cuda.current_device(),
)
torch.distributed.all_reduce(
missing_flag,
op=torch.distributed.ReduceOp.MAX,
group=get_expert_model_parallel_group(),
)
if missing_flag.item() != 0:
raise ValueError(
f"TEGroupedMLP missing expert weights on at least one EP rank "
f"(local missing on rank {ep_rank}: {local_missing})"
)
elif local_missing:
raise ValueError(f"TEGroupedMLP missing expert weights: {local_missing}")

# Scales / aux values are small and shared across local experts. Move to CPU once
# so the gather payload doesn't repeatedly clone GPU tensors.
weight_scale_cpu = weight_scale.detach().cpu().clone() if weight_scale is not None else None
weight_scale_2_cpu = (
weight_scale_2.detach().cpu().clone() if weight_scale_2 is not None else None
)
name_to_value_cpu = {
k: v.detach().cpu().clone() for k, v in name_to_value.items() if k != "output_scale"
}

# Record per-layer quant config for ALL global experts on every rank, not just
# the local slice -- `_per_layer_quant_config` is a local in-memory dict that
# later gets serialized into hf_quant_config.json; if we only recorded the
# local 1/EP slice, the writer rank's dict would be missing (EP-1)/EP of the
# routed-expert entries. Within a single TEGroupedMLP layer all routed experts
# share the same qformat / block_size by construction (one quantizer pattern
# in the recipe matches the whole `*mixer.experts.*` glob), so it's safe to
# use the local qformat/block_size for the global record.
num_total_experts = num_experts * ep_size
for global_id in range(num_total_experts):
self._record_layer_quant_config(prefix.format(global_id) + ".", qformat, block_size)

local_expert_state: dict[str, torch.Tensor] = {}

for local_id in range(num_experts):
global_id = local_expert_indices[local_id]
expert_prefix = prefix.format(global_id) + "."
weight_key = f"weight{local_id}"

weight = state_dict[weight_key].to(self.dtype).cpu()

if weight_scale is None:
self._state_dict[expert_prefix + "weight"] = weight
if weight_scale_cpu is None:
local_expert_state[expert_prefix + "weight"] = weight
else:
self._state_dict[expert_prefix + "weight"] = to_quantized_weight(
local_expert_state[expert_prefix + "weight"] = to_quantized_weight(
weight,
weight_scale,
weight_scale_cpu,
qformat,
weight_scale_2,
weight_scale_2_cpu,
block_size,
)
self._state_dict[expert_prefix + "weight_scale"] = weight_scale.detach().clone()

if weight_scale_2 is not None:
self._state_dict[expert_prefix + "weight_scale_2"] = weight_scale_2.detach().clone()

for key, val in name_to_value.items():
if key == "output_scale":
continue
for expert_id in range(num_experts):
expert_prefix = prefix.format(expert_id) + "."
self._state_dict[expert_prefix + key] = val.detach().clone()
local_expert_state[expert_prefix + "weight_scale"] = weight_scale_cpu.clone()

if weight_scale_2_cpu is not None:
local_expert_state[expert_prefix + "weight_scale_2"] = weight_scale_2_cpu.clone()

for key, val in name_to_value_cpu.items():
local_expert_state[expert_prefix + key] = val.clone()

if ep_size > 1:
# all_gather_object can pickle Python objects but trips on quantized uint8
# tensors whose UntypedStorage has no dtype attr. Round-trip through
# torch.save's byte stream -- PyTorch's own tensor codec handles them.
_buf = io.BytesIO()
torch.save(local_expert_state, _buf)
local_bytes = _buf.getvalue()
del _buf
gathered_bytes: list = [None] * ep_size
torch.distributed.all_gather_object(
gathered_bytes, local_bytes, group=get_expert_model_parallel_group()
)
del local_bytes
for b in gathered_bytes:
# weights_only=False: payload is generated by us via torch.save in this
# same function on a sibling rank in the EP process group of this job
# -- it never leaves the cluster's collective and is not user-supplied.
# weights_only=True would refuse to deserialize the dict[str, Tensor]
# because quantized uint8 tensors store custom storage metadata that
# the safe-loader allowlist doesn't cover.
s = torch.load(io.BytesIO(b), map_location="cpu", weights_only=False)
self._state_dict.update(s)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
del gathered_bytes
else:
self._state_dict.update(local_expert_state)

def _qkv_slicing(
self,
Expand Down
Loading