From 2b1ef75dccb7497a10bb5f64c9a71059f83ce914 Mon Sep 17 00:00:00 2001 From: Jinhang Choi Date: Fri, 8 May 2026 01:28:25 -0700 Subject: [PATCH 1/7] safetensor metadata mismatch fix Signed-off-by: Jinhang Choi --- modelopt/torch/export/plugins/mcore_custom.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index 0b6ce7a35df..35e45109e8f 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -305,6 +305,12 @@ def save_safetensors_by_layer_index( meta_filename = filename + ".json" ckpt_filename = filename + ".safetensors" + # Write safetensors first, then build the per-layer meta JSON from the same dict. + # Order matters: any late mutations to layer_state_dict (e.g. MTP tensors added after + # the dict was first constructed) must be captured by both files. Writing safetensors + # first ensures the JSON is always consistent with what is physically on disk. + save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"}) + weight_map = {} layer_total_size = 0 for key, val in layer_state_dict.items(): @@ -318,7 +324,6 @@ def save_safetensors_by_layer_index( f, indent=4, ) - save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"}) # [TODO]: this global barrier needs to be replaced with something safer torch.distributed.barrier() From 9ac2c98f69704f64dbff6e546e71b6abf360a2f7 Mon Sep 17 00:00:00 2001 From: Jinhang Choi Date: Wed, 13 May 2026 09:33:31 -0700 Subject: [PATCH 2/7] Feedback: concise comment for safetensor ordering Signed-off-by: Jinhang Choi --- modelopt/torch/export/plugins/mcore_custom.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index 35e45109e8f..73779627685 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -305,11 +305,13 @@ def save_safetensors_by_layer_index( meta_filename = filename + ".json" ckpt_filename = filename + ".safetensors" - # Write safetensors first, then build the per-layer meta JSON from the same dict. - # Order matters: any late mutations to layer_state_dict (e.g. MTP tensors added after - # the dict was first constructed) must be captured by both files. Writing safetensors - # first ensures the JSON is always consistent with what is physically on disk. - save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"}) + # Keep safetensors-first ordering so late layer_state_dict updates are captured + # in the layer .safetensors shard and reflected in its per-layer metadata JSON. + save_file( + layer_state_dict, + save_directory + "/" + ckpt_filename, + metadata={"format": "pt"}, + ) weight_map = {} layer_total_size = 0 From 9e5f3ad25f653c01bb55017a52e3727ce3ccec51 Mon Sep 17 00:00:00 2001 From: Jinhang Choi Date: Thu, 4 Jun 2026 17:21:04 -0700 Subject: [PATCH 3/7] Feedback: update save_safetensors too Signed-off-by: Jinhang Choi --- modelopt/torch/export/plugins/mcore_custom.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index 73779627685..cf1583ab858 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -251,6 +251,10 @@ def save_safetensors(state_dict, save_directory: str | os.PathLike): meta_filename = filename + ".json" ckpt_filename = filename + ".safetensors" + # Keep safetensors-first ordering so late shard updates are reflected + # in both the shard file and the shard metadata JSON. + save_file(tensors, save_directory + "/" + ckpt_filename, metadata={"format": "pt"}) + weight_map = {} local_total_size = 0 @@ -264,7 +268,6 @@ def save_safetensors(state_dict, save_directory: str | os.PathLike): f, indent=4, ) - save_file(tensors, save_directory + "/" + ckpt_filename, metadata={"format": "pt"}) # Barrier to ensure all ranks have written the metadata torch.distributed.barrier() From 6f51d11f0ef5740c57f379926d1f055bf7c7ebd4 Mon Sep 17 00:00:00 2001 From: Jinhang Choi Date: Fri, 5 Jun 2026 01:02:46 -0700 Subject: [PATCH 4/7] Feedback: freeze snapshot Signed-off-by: Jinhang Choi --- modelopt/torch/export/plugins/mcore_custom.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index cf1583ab858..f231da48433 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -251,14 +251,14 @@ def save_safetensors(state_dict, save_directory: str | os.PathLike): meta_filename = filename + ".json" ckpt_filename = filename + ".safetensors" - # Keep safetensors-first ordering so late shard updates are reflected - # in both the shard file and the shard metadata JSON. - save_file(tensors, save_directory + "/" + ckpt_filename, metadata={"format": "pt"}) + # Freeze a per-shard snapshot so safetensors and JSON are emitted from the same view. + frozen_tensors = dict(tensors) + save_file(frozen_tensors, save_directory + "/" + ckpt_filename, metadata={"format": "pt"}) weight_map = {} local_total_size = 0 - for key, val in tensors.items(): + for key, val in frozen_tensors.items(): local_total_size += val.numel() * val.element_size() weight_map[key] = ckpt_filename @@ -308,17 +308,17 @@ def save_safetensors_by_layer_index( meta_filename = filename + ".json" ckpt_filename = filename + ".safetensors" - # Keep safetensors-first ordering so late layer_state_dict updates are captured - # in the layer .safetensors shard and reflected in its per-layer metadata JSON. + # Freeze a per-layer snapshot so safetensors and JSON are emitted from the same view. + frozen_layer_state_dict = dict(layer_state_dict) save_file( - layer_state_dict, + frozen_layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"}, ) weight_map = {} layer_total_size = 0 - for key, val in layer_state_dict.items(): + for key, val in frozen_layer_state_dict.items(): tensor_size = val.numel() * val.element_size() layer_total_size += tensor_size weight_map[key] = ckpt_filename From 1c064a1185dc654f41d1701a6643014bff050de1 Mon Sep 17 00:00:00 2001 From: Jinhang Choi Date: Fri, 5 Jun 2026 01:24:39 -0700 Subject: [PATCH 5/7] Feedback: test_mcore_save_safetensors.py Signed-off-by: Jinhang Choi --- .../export/test_mcore_save_safetensors.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 tests/unit/torch/export/test_mcore_save_safetensors.py diff --git a/tests/unit/torch/export/test_mcore_save_safetensors.py b/tests/unit/torch/export/test_mcore_save_safetensors.py new file mode 100644 index 00000000000..9f2050c405f --- /dev/null +++ b/tests/unit/torch/export/test_mcore_save_safetensors.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import torch + +from modelopt.torch.export.plugins import mcore_custom + + +def test_save_safetensors_by_layer_index_uses_single_snapshot(monkeypatch, tmp_path): + """Shard JSON and safetensors must be produced from the same key snapshot.""" + layer_state_dicts = {1: {"layers.0.weight": torch.arange(4, dtype=torch.float32)}} + late_key = "mtp.late.weight" + written_keys_by_shard = {} + + def _fake_save_file(tensors, path, metadata=None): + written_keys_by_shard[path.split("/")[-1]] = set(tensors.keys()) + # Simulate a late mutation against the original source dict (not the writer snapshot). + layer_state_dicts[1][late_key] = torch.ones(1, dtype=torch.float32) + + monkeypatch.setattr(mcore_custom, "save_file", _fake_save_file) + monkeypatch.setattr(torch.distributed, "barrier", lambda: None) + monkeypatch.setattr(torch.distributed, "get_rank", lambda: 0) + + mcore_custom.save_safetensors_by_layer_index( + layer_state_dicts=layer_state_dicts, + total_layers=1, + save_directory=str(tmp_path), + name_template="model-{:05d}-of-{:05d}", + ) + + shard_name = "model-00001-of-00001.safetensors" + with open(tmp_path / "model-00001-of-00001.json") as f: + shard_meta = json.load(f) + with open(tmp_path / "model.safetensors.index.json") as f: + index_meta = json.load(f) + + json_keys = set(shard_meta["weight_map"].keys()) + assert json_keys == written_keys_by_shard[shard_name] + assert late_key not in json_keys + assert late_key not in index_meta["weight_map"] From 43473e9f7ce012a6f2cadcafdd240800a8dd3cc8 Mon Sep 17 00:00:00 2001 From: Jinhang Choi Date: Fri, 5 Jun 2026 12:45:59 -0700 Subject: [PATCH 6/7] Feedback: add test_unified_export_megatron_pp2_mtp_metadata_matches_shards for PP==2 in test_unified_export_megatron.py Signed-off-by: Jinhang Choi --- .../export/test_unified_export_megatron.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index f818cb3594c..f15bc735592 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -352,6 +352,75 @@ def test_qkv_slicing_gqa_tp2(dist_workers_size_2, tmp_path): dist_workers_size_2.run(partial(_test_qkv_slicing_gqa_tp2, tmp_path)) +def _test_export_pp2_mtp_metadata_matches_shards(tmp_path, model_dir, rank, size): + """With PP>1, per-shard JSON keys should exist in the referenced safetensors shard.""" + config = transformers.AutoConfig.from_pretrained(model_dir) + + model = get_mcore_gpt_model( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=size, + initialize_megatron=True, + num_layers=config.num_hidden_layers, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_query_groups=config.num_key_value_heads, + ffn_hidden_size=config.intermediate_size, + max_sequence_length=config.max_position_embeddings, + vocab_size=config.vocab_size, + activation_func="swiglu", + normalization="RMSNorm", + transformer_impl="modelopt", + ).cuda() + + export_dir = tmp_path / "export_pp2" + original_get_mtp_state_dict = GPTModelExporter._get_mtp_state_dict + + # Simulate stage-local MTP tensors (only on the last PP rank). + def _fake_get_mtp_state_dict(self): + if rank != size - 1: + return {} + return {f"mtp.injected.rank{rank}.weight": torch.ones(8, dtype=torch.bfloat16).cpu()} + + GPTModelExporter._get_mtp_state_dict = _fake_get_mtp_state_dict + + try: + export_mcore_gpt_to_hf( + model, + model_dir, + dtype=torch.bfloat16, + export_dir=str(export_dir), + ) + finally: + GPTModelExporter._get_mtp_state_dict = original_get_mtp_state_dict + + if rank == 0: + shard_json_files = sorted(export_dir.glob("model-*.json")) + assert shard_json_files, "no per-shard metadata json files found" + + shard_keys_cache = {} + all_weight_map_keys = set() + for shard_json_file in shard_json_files: + with open(shard_json_file) as f: + shard_meta = json.load(f) + for key, shard_file in shard_meta["weight_map"].items(): + all_weight_map_keys.add(key) + if shard_file not in shard_keys_cache: + with safe_open(str(export_dir / shard_file), framework="pt", device="cpu") as sf: + shard_keys_cache[shard_file] = set(sf.keys()) + assert key in shard_keys_cache[shard_file], ( + f"key '{key}' from {shard_json_file.name} missing in {shard_file}" + ) + + assert any(key.startswith("mtp.injected.") for key in all_weight_map_keys), ( + "expected injected mtp.* key missing from shard metadata/index map" + ) + + +def test_unified_export_megatron_pp2_mtp_metadata_matches_shards(dist_workers_size_2, tmp_path): + model_dir = create_tiny_llama_dir(tmp_path) + dist_workers_size_2.run(partial(_test_export_pp2_mtp_metadata_matches_shards, tmp_path, model_dir)) + + def test_qkv_slicing_records_hf_excludes_for_unquantized_fused_qkv(): """Unquantized fused MCore linear_qkv should become HF q/k/v excludes.""" exporter = object.__new__(GPTModelExporter) From 799bce3cf985c29a4507c1ebc8e33c56424253a6 Mon Sep 17 00:00:00 2001 From: Jinhang Choi Date: Fri, 5 Jun 2026 15:20:24 -0700 Subject: [PATCH 7/7] linter: style fix Signed-off-by: Jinhang Choi --- .../torch/export/test_unified_export_megatron.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index f15bc735592..e4d1968d4ef 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -405,7 +405,9 @@ def _fake_get_mtp_state_dict(self): for key, shard_file in shard_meta["weight_map"].items(): all_weight_map_keys.add(key) if shard_file not in shard_keys_cache: - with safe_open(str(export_dir / shard_file), framework="pt", device="cpu") as sf: + with safe_open( + str(export_dir / shard_file), framework="pt", device="cpu" + ) as sf: shard_keys_cache[shard_file] = set(sf.keys()) assert key in shard_keys_cache[shard_file], ( f"key '{key}' from {shard_json_file.name} missing in {shard_file}" @@ -418,7 +420,9 @@ def _fake_get_mtp_state_dict(self): def test_unified_export_megatron_pp2_mtp_metadata_matches_shards(dist_workers_size_2, tmp_path): model_dir = create_tiny_llama_dir(tmp_path) - dist_workers_size_2.run(partial(_test_export_pp2_mtp_metadata_matches_shards, tmp_path, model_dir)) + dist_workers_size_2.run( + partial(_test_export_pp2_mtp_metadata_matches_shards, tmp_path, model_dir) + ) def test_qkv_slicing_records_hf_excludes_for_unquantized_fused_qkv():