diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 3e111381adc..00bf55e424d 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -18,6 +18,7 @@ """Code that export quantized Megatron Core models for deployment.""" +import io import json import os import tempfile @@ -85,6 +86,10 @@ HybridModel = MambaModel from megatron.core.models.multimodal.llava_model import LLaVAModel from megatron.core.parallel_state import ( + get_data_parallel_rank, + get_expert_model_parallel_group, + get_expert_model_parallel_rank, + get_expert_model_parallel_world_size, get_pipeline_model_parallel_rank, get_pipeline_model_parallel_world_size, get_tensor_model_parallel_rank, @@ -300,50 +305,61 @@ def save_pretrained( # We use the last PP rank and the 1st EP rank to write the config because # medusa_heads and eagle_module only exist in the last stage. if is_last_stage_main_rank: - # Baseline: for a local source, copy every non-safetensors file - # (tokenizer, remote_code *.py, README, etc.); for a Hub-ID source, - # snapshot_download just the *.py sidecars (tokenizer comes via the - # AutoTokenizer fallback below). modelopt-owned files (config.json, - # generation_config.json, hf_quant_config.json, preprocessor_config.json) - # are overwritten below. - if self._hf_pretrained_model_name is not None: - if os.path.isdir(self._hf_pretrained_model_name): - copy_non_safetensor_files_from_ckpt( - self._hf_pretrained_model_name, save_directory - ) - else: - copy_hf_ckpt_remote_code(self._hf_pretrained_model_name, save_directory) - self._hf_config.save_pretrained(save_directory) - try: - generation_config = transformers.GenerationConfig.from_pretrained( - self._hf_pretrained_model_name, - trust_remote_code=self.trust_remote_code, - ) - generation_config.save_pretrained(save_directory) - except OSError: - pass - # Hub-ID / None source: fetch tokenizer files via AutoTokenizer. - if self._hf_pretrained_model_name is None or not os.path.isdir( - self._hf_pretrained_model_name - ): + # Sidecar file writes (tokenizer, generation_config, hf_config, + # preprocessor, etc.) all target the same save_directory. At DP > 1 + # or EP > 1 there are multiple `is_last_stage_main_rank` ranks; if + # they all race on these writes the contents stay correct (last + # writer wins with identical bytes) but partial-write windows can + # still corrupt a concurrent read by another rank. Pin to one rank. + if get_data_parallel_rank() == 0 and get_expert_model_parallel_rank() == 0: + # Baseline: for a local source, copy every non-safetensors file + # (tokenizer, remote_code *.py, README, etc.); for a Hub-ID source, + # snapshot_download just the *.py sidecars (tokenizer comes via the + # AutoTokenizer fallback below). modelopt-owned files (config.json, + # generation_config.json, hf_quant_config.json, preprocessor_config.json) + # are overwritten below. + if self._hf_pretrained_model_name is not None: + if os.path.isdir(self._hf_pretrained_model_name): + copy_non_safetensor_files_from_ckpt( + self._hf_pretrained_model_name, save_directory + ) + else: + copy_hf_ckpt_remote_code(self._hf_pretrained_model_name, save_directory) + self._hf_config.save_pretrained(save_directory) try: - tokenizer = transformers.AutoTokenizer.from_pretrained( + generation_config = transformers.GenerationConfig.from_pretrained( self._hf_pretrained_model_name, trust_remote_code=self.trust_remote_code, ) - tokenizer.save_pretrained(save_directory) - except (OSError, TypeError, ValueError, ImportError): + generation_config.save_pretrained(save_directory) + except OSError: + pass + # Hub-ID / None source: fetch tokenizer files via AutoTokenizer. + if self._hf_pretrained_model_name is None or not os.path.isdir( + self._hf_pretrained_model_name + ): + try: + tokenizer = transformers.AutoTokenizer.from_pretrained( + self._hf_pretrained_model_name, + trust_remote_code=self.trust_remote_code, + ) + tokenizer.save_pretrained(save_directory) + except (OSError, TypeError, ValueError, ImportError): + pass + try: + # Load and save preprocessor config from the original model + processor = AutoProcessor.from_pretrained( + self._hf_pretrained_model_name, trust_remote_code=self.trust_remote_code + ) + if hasattr(processor, "image_processor"): + processor.image_processor.save_pretrained(save_directory) + except (OSError, ValueError, ImportError): pass - try: - # Load and save preprocessor config from the original model - processor = AutoProcessor.from_pretrained( - self._hf_pretrained_model_name, trust_remote_code=self.trust_remote_code - ) - if hasattr(processor, "image_processor"): - processor.image_processor.save_pretrained(save_directory) - except (OSError, ValueError, ImportError): - pass + # MTP load is in-memory per-rank state (mutates this rank's + # layer_state_dicts that feeds the per-rank safetensors save + # later); it must run on every last-stage main rank, not just + # the single writer. mtp_state_dict = self._get_mtp_state_dict() if len(mtp_state_dict) > 0: layer_state_dicts[self.model.config.num_layers].update(mtp_state_dict) @@ -354,7 +370,17 @@ def save_pretrained( # kv_cache_dtype is only set on attention-owning ranks; writer rank may not be one. gathered_kv_cache_dtype = self._gather_kv_cache_dtype() - if is_last_stage_main_rank and quantization is not None: + # Pin hf_quant_config.json assembly + write to a single rank for the same + # reason as the sidecar block above (and for symmetry with the config.json + # patch below). Self._hf_quant_config remains empty on non-writer ranks, + # which makes the config.json gate at the bottom of this method naturally + # a no-op on them. + if ( + is_last_stage_main_rank + and get_data_parallel_rank() == 0 + and get_expert_model_parallel_rank() == 0 + and quantization is not None + ): if combined_layer_config_dict: quantization_config = process_layer_quant_config(combined_layer_config_dict) quantization_config["exclude_modules"] = combined_exclude_modules @@ -400,9 +426,23 @@ def save_pretrained( # Barrier to ensure the export_dir has been created. torch.distributed.barrier() - # Newer versions of VLLM expect config.json with hf_quant_config + # Newer versions of VLLM expect config.json with hf_quant_config. + # Pin the writer to a single rank: last PP stage, TP rank 0 (handled by + # `is_last_stage_main_rank`), DP rank 0, EP rank 0. Without the DP/EP gates, + # multiple ranks satisfy `is_last_stage_main_rank` simultaneously (one per + # DPxEP cell), read-modify-write the same file, and any pair can interleave + # such that another rank reads a truncated mid-write file and raises + # JSONDecodeError. Bracket with barriers so every other rank waits for the + # single writer. + torch.distributed.barrier() config_json_file = save_directory + "/config.json" - if self._hf_quant_config and os.path.exists(config_json_file): + if ( + is_last_stage_main_rank + and get_data_parallel_rank() == 0 + and get_expert_model_parallel_rank() == 0 + and self._hf_quant_config + and os.path.exists(config_json_file) + ): with open(config_json_file) as f: config_dict = json.load(f) config_dict["quantization_config"] = convert_hf_quant_config_format( @@ -410,6 +450,7 @@ def save_pretrained( ) with open(config_json_file, "w") as f: json.dump(config_dict, f, indent=4) + torch.distributed.barrier() # save_safetensors(state_dict, save_directory) save_safetensors_by_layer_index( @@ -614,7 +655,7 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]: ) single_safetensors_file = None except EntryNotFoundError: - # Model uses a single unsharded safetensors file — check it for MTP weights. + # Model uses a single unsharded safetensors file -- check it for MTP weights. safetensors_index_file = None try: single_safetensors_file = Path( @@ -1040,6 +1081,17 @@ def _grouped_mlp_slicing(self, module, prefix, parallel_config=None): quantization state from the module, then iterates over experts and saves each expert's weight (and scales if quantized) under the HF-style per-expert prefix. + Collective behavior (EP > 1): + When the expert-model-parallel world size is greater than 1, ``num_experts`` is the + number of *local* experts on this rank, and ``range(num_experts)`` numbers them + 0..N-1, which collides with every other EP rank. This method translates local + expert ids to *global* ids (preferring ``module.local_expert_indices`` when + exposed, falling back to the standard Megatron contiguous layout + ``[ep_rank*N, (ep_rank+1)*N - 1]``) and ``all_gather_object`` the resulting + per-expert state across the EP process group so every rank in the EP group ends + up with the union of experts for this layer. All ranks in the EP group MUST call + this function for the same layer at the same logical step or the gather will hang. + This is the reverse of _grouped_mlp_merging in the importer. """ num_experts = module.num_gemms @@ -1062,37 +1114,134 @@ def _grouped_mlp_slicing(self, module, prefix, parallel_config=None): state_dict = module.state_dict() - for expert_id in range(num_experts): - expert_prefix = prefix.format(expert_id) + "." - self._record_layer_quant_config(expert_prefix, qformat, block_size) - weight_key = f"weight{expert_id}" + ep_size = ( + get_expert_model_parallel_world_size() if torch.distributed.is_initialized() else 1 + ) + ep_rank = get_expert_model_parallel_rank() if torch.distributed.is_initialized() else 0 + + # Prefer the model's authoritative local-to-global expert mapping if exposed. + # The standard Megatron contiguous layout (rank r owns experts [r*N, (r+1)*N - 1]) + # is the fallback when the module doesn't expose local_expert_indices. + # + # Normalize to a plain Python list of ints regardless of whether Megatron exposes + # this as a list, tuple, torch.Tensor, or numpy array. A bare `or` against the raw + # attribute trips `bool(tensor)` for multi-element tensors, and an empty list (or + # 0-d tensor) would silently fall through to the contiguous-layout fallback. + indices = getattr(module, "local_expert_indices", None) + if indices is None and getattr(module, "experts", None) is not None: + indices = getattr(module.experts, "local_expert_indices", None) + if indices is None: + local_expert_indices = [ep_rank * num_experts + i for i in range(num_experts)] + elif isinstance(indices, torch.Tensor): + local_expert_indices = indices.detach().cpu().tolist() + else: + local_expert_indices = [int(i) for i in indices] + if len(local_expert_indices) != num_experts: + raise ValueError( + f"local_expert_indices length {len(local_expert_indices)} doesn't match " + f"module.num_gemms {num_experts}" + ) - if weight_key not in state_dict: - raise ValueError(f"Missing expected TEGroupedMLP expert weight: {weight_key}") + # Collective-safe missing-key detection: every EP rank checks for its weight keys + # locally, then we all_reduce(MAX) over a 0/1 flag so a single rank's missing key + # surfaces on every rank. Tensor must be on the current CUDA device -- NCCL has no + # CPU backend on the EP process group. + local_missing = [ + k for k in (f"weight{i}" for i in range(num_experts)) if k not in state_dict + ] + if ep_size > 1: + missing_flag = torch.tensor( + [1 if local_missing else 0], + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + torch.distributed.all_reduce( + missing_flag, + op=torch.distributed.ReduceOp.MAX, + group=get_expert_model_parallel_group(), + ) + if missing_flag.item() != 0: + raise ValueError( + f"TEGroupedMLP missing expert weights on at least one EP rank " + f"(local missing on rank {ep_rank}: {local_missing})" + ) + elif local_missing: + raise ValueError(f"TEGroupedMLP missing expert weights: {local_missing}") + + # Scales / aux values are small and shared across local experts. Move to CPU once + # so the gather payload doesn't repeatedly clone GPU tensors. + weight_scale_cpu = weight_scale.detach().cpu().clone() if weight_scale is not None else None + weight_scale_2_cpu = ( + weight_scale_2.detach().cpu().clone() if weight_scale_2 is not None else None + ) + name_to_value_cpu = { + k: v.detach().cpu().clone() for k, v in name_to_value.items() if k != "output_scale" + } + + # Record per-layer quant config for ALL global experts on every rank, not just + # the local slice -- `_per_layer_quant_config` is a local in-memory dict that + # later gets serialized into hf_quant_config.json; if we only recorded the + # local 1/EP slice, the writer rank's dict would be missing (EP-1)/EP of the + # routed-expert entries. Within a single TEGroupedMLP layer all routed experts + # share the same qformat / block_size by construction (one quantizer pattern + # in the recipe matches the whole `*mixer.experts.*` glob), so it's safe to + # use the local qformat/block_size for the global record. + num_total_experts = num_experts * ep_size + for global_id in range(num_total_experts): + self._record_layer_quant_config(prefix.format(global_id) + ".", qformat, block_size) + + local_expert_state: dict[str, torch.Tensor] = {} + + for local_id in range(num_experts): + global_id = local_expert_indices[local_id] + expert_prefix = prefix.format(global_id) + "." + weight_key = f"weight{local_id}" weight = state_dict[weight_key].to(self.dtype).cpu() - if weight_scale is None: - self._state_dict[expert_prefix + "weight"] = weight + if weight_scale_cpu is None: + local_expert_state[expert_prefix + "weight"] = weight else: - self._state_dict[expert_prefix + "weight"] = to_quantized_weight( + local_expert_state[expert_prefix + "weight"] = to_quantized_weight( weight, - weight_scale, + weight_scale_cpu, qformat, - weight_scale_2, + weight_scale_2_cpu, block_size, ) - self._state_dict[expert_prefix + "weight_scale"] = weight_scale.detach().clone() - - if weight_scale_2 is not None: - self._state_dict[expert_prefix + "weight_scale_2"] = weight_scale_2.detach().clone() - - for key, val in name_to_value.items(): - if key == "output_scale": - continue - for expert_id in range(num_experts): - expert_prefix = prefix.format(expert_id) + "." - self._state_dict[expert_prefix + key] = val.detach().clone() + local_expert_state[expert_prefix + "weight_scale"] = weight_scale_cpu.clone() + + if weight_scale_2_cpu is not None: + local_expert_state[expert_prefix + "weight_scale_2"] = weight_scale_2_cpu.clone() + + for key, val in name_to_value_cpu.items(): + local_expert_state[expert_prefix + key] = val.clone() + + if ep_size > 1: + # all_gather_object can pickle Python objects but trips on quantized uint8 + # tensors whose UntypedStorage has no dtype attr. Round-trip through + # torch.save's byte stream -- PyTorch's own tensor codec handles them. + _buf = io.BytesIO() + torch.save(local_expert_state, _buf) + local_bytes = _buf.getvalue() + del _buf + gathered_bytes: list = [None] * ep_size + torch.distributed.all_gather_object( + gathered_bytes, local_bytes, group=get_expert_model_parallel_group() + ) + del local_bytes + for b in gathered_bytes: + # weights_only=False: payload is generated by us via torch.save in this + # same function on a sibling rank in the EP process group of this job + # -- it never leaves the cluster's collective and is not user-supplied. + # weights_only=True would refuse to deserialize the dict[str, Tensor] + # because quantized uint8 tensors store custom storage metadata that + # the safe-loader allowlist doesn't cover. + s = torch.load(io.BytesIO(b), map_location="cpu", weights_only=False) + self._state_dict.update(s) + del gathered_bytes + else: + self._state_dict.update(local_expert_state) def _qkv_slicing( self,