Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions modelopt/torch/export/plugins/mcore_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,14 @@ def save_safetensors(state_dict, save_directory: str | os.PathLike):
meta_filename = filename + ".json"
ckpt_filename = filename + ".safetensors"

# Freeze a per-shard snapshot so safetensors and JSON are emitted from the same view.
frozen_tensors = dict(tensors)
save_file(frozen_tensors, save_directory + "/" + ckpt_filename, metadata={"format": "pt"})

weight_map = {}
local_total_size = 0

for key, val in tensors.items():
for key, val in frozen_tensors.items():
local_total_size += val.numel() * val.element_size()
weight_map[key] = ckpt_filename

Expand All @@ -264,7 +268,6 @@ def save_safetensors(state_dict, save_directory: str | os.PathLike):
f,
indent=4,
)
save_file(tensors, save_directory + "/" + ckpt_filename, metadata={"format": "pt"})

# Barrier to ensure all ranks have written the metadata
torch.distributed.barrier()
Expand Down Expand Up @@ -305,9 +308,17 @@ def save_safetensors_by_layer_index(
meta_filename = filename + ".json"
ckpt_filename = filename + ".safetensors"

# Freeze a per-layer snapshot so safetensors and JSON are emitted from the same view.
frozen_layer_state_dict = dict(layer_state_dict)
save_file(
frozen_layer_state_dict,
save_directory + "/" + ckpt_filename,
metadata={"format": "pt"},
)

weight_map = {}
layer_total_size = 0
for key, val in layer_state_dict.items():
for key, val in frozen_layer_state_dict.items():
tensor_size = val.numel() * val.element_size()
layer_total_size += tensor_size
weight_map[key] = ckpt_filename
Expand All @@ -318,7 +329,6 @@ def save_safetensors_by_layer_index(
f,
indent=4,
)
save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"})

# [TODO]: this global barrier needs to be replaced with something safer
torch.distributed.barrier()
Expand Down
73 changes: 73 additions & 0 deletions tests/gpu_megatron/torch/export/test_unified_export_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,79 @@ def test_qkv_slicing_gqa_tp2(dist_workers_size_2, tmp_path):
dist_workers_size_2.run(partial(_test_qkv_slicing_gqa_tp2, tmp_path))


def _test_export_pp2_mtp_metadata_matches_shards(tmp_path, model_dir, rank, size):
    """Verify per-shard JSON metadata agrees with the safetensors shard it references.

    The model is built with ``pipeline_model_parallel_size=size`` and exported while
    ``GPTModelExporter._get_mtp_state_dict`` is temporarily replaced by a stub that
    returns one injected ``mtp.injected.*`` tensor on the last rank and an empty dict
    on every other rank. Rank 0 then checks that each ``weight_map`` key listed in a
    ``model-*.json`` file exists in the shard file named by that entry, and that the
    injected key shows up in at least one of those maps.
    """
    hf_config = transformers.AutoConfig.from_pretrained(model_dir)

    model_kwargs = {
        "tensor_model_parallel_size": 1,
        "pipeline_model_parallel_size": size,
        "initialize_megatron": True,
        "num_layers": hf_config.num_hidden_layers,
        "hidden_size": hf_config.hidden_size,
        "num_attention_heads": hf_config.num_attention_heads,
        "num_query_groups": hf_config.num_key_value_heads,
        "ffn_hidden_size": hf_config.intermediate_size,
        "max_sequence_length": hf_config.max_position_embeddings,
        "vocab_size": hf_config.vocab_size,
        "activation_func": "swiglu",
        "normalization": "RMSNorm",
        "transformer_impl": "modelopt",
    }
    model = get_mcore_gpt_model(**model_kwargs).cuda()

    output_dir = tmp_path / "export_pp2"
    saved_get_mtp_state_dict = GPTModelExporter._get_mtp_state_dict
    is_last_rank = rank == size - 1

    def _stub_get_mtp_state_dict(self):
        """Stand-in that yields a stage-local MTP tensor only on the last rank."""
        if is_last_rank:
            return {f"mtp.injected.rank{rank}.weight": torch.ones(8, dtype=torch.bfloat16).cpu()}
        return {}

    GPTModelExporter._get_mtp_state_dict = _stub_get_mtp_state_dict
    try:
        export_mcore_gpt_to_hf(
            model,
            model_dir,
            dtype=torch.bfloat16,
            export_dir=str(output_dir),
        )
    finally:
        # Always restore the real method so the stub cannot leak past this helper.
        GPTModelExporter._get_mtp_state_dict = saved_get_mtp_state_dict

    # Only rank 0 inspects the exported files.
    if rank != 0:
        return

    metadata_paths = sorted(output_dir.glob("model-*.json"))
    assert metadata_paths, "no per-shard metadata json files found"

    keys_by_shard = {}  # shard filename -> tensor names stored in it (each shard opened once)
    seen_keys = set()
    for metadata_path in metadata_paths:
        with open(metadata_path) as fh:
            weight_map = json.load(fh)["weight_map"]

        for tensor_name, shard_name in weight_map.items():
            seen_keys.add(tensor_name)
            if shard_name not in keys_by_shard:
                with safe_open(
                    str(output_dir / shard_name), framework="pt", device="cpu"
                ) as shard:
                    keys_by_shard[shard_name] = set(shard.keys())
            assert tensor_name in keys_by_shard[shard_name], (
                f"key '{tensor_name}' from {metadata_path.name} missing in {shard_name}"
            )

    assert any(name.startswith("mtp.injected.") for name in seen_keys), (
        "expected injected mtp.* key missing from shard metadata/index map"
    )


def test_unified_export_megatron_pp2_mtp_metadata_matches_shards(dist_workers_size_2, tmp_path):
    """Run the shard-metadata consistency check through the ``dist_workers_size_2`` fixture."""
    tiny_model_dir = create_tiny_llama_dir(tmp_path)
    # Bind the leading arguments here; the fixture's ``run`` presumably supplies the
    # remaining ``rank`` and ``size`` arguments -- confirm against the fixture definition.
    worker_fn = partial(_test_export_pp2_mtp_metadata_matches_shards, tmp_path, tiny_model_dir)
    dist_workers_size_2.run(worker_fn)


def test_qkv_slicing_records_hf_excludes_for_unquantized_fused_qkv():
"""Unquantized fused MCore linear_qkv should become HF q/k/v excludes."""
exporter = object.__new__(GPTModelExporter)
Expand Down
54 changes: 54 additions & 0 deletions tests/unit/torch/export/test_mcore_save_safetensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

import torch

from modelopt.torch.export.plugins import mcore_custom


def test_save_safetensors_by_layer_index_uses_single_snapshot(monkeypatch, tmp_path):
    """The shard JSON and the safetensors payload must describe the same set of keys.

    ``save_file`` is replaced by a recorder that notes which keys it was handed and
    then adds an extra key to the caller's source dict. The test asserts that the
    extra key appears in neither the per-shard JSON nor the aggregated index.
    """
    late_key = "mtp.late.weight"
    layer_state_dicts = {1: {"layers.0.weight": torch.arange(4, dtype=torch.float32)}}
    written_keys_by_shard = {}

    def _recording_save_file(tensors, path, metadata=None):
        """Record the keys passed to the writer, then mutate the original source dict."""
        shard_basename = path.rsplit("/", 1)[-1]
        written_keys_by_shard[shard_basename] = set(tensors)
        # The late insert targets the source dict, not the object handed to the writer.
        layer_state_dicts[1][late_key] = torch.ones(1, dtype=torch.float32)

    monkeypatch.setattr(mcore_custom, "save_file", _recording_save_file)
    # Stub the torch.distributed calls so no process group is required.
    monkeypatch.setattr(torch.distributed, "barrier", lambda: None)
    monkeypatch.setattr(torch.distributed, "get_rank", lambda: 0)

    mcore_custom.save_safetensors_by_layer_index(
        layer_state_dicts=layer_state_dicts,
        total_layers=1,
        save_directory=str(tmp_path),
        name_template="model-{:05d}-of-{:05d}",
    )

    shard_meta = json.loads((tmp_path / "model-00001-of-00001.json").read_text())
    index_meta = json.loads((tmp_path / "model.safetensors.index.json").read_text())

    json_keys = set(shard_meta["weight_map"])
    assert json_keys == written_keys_by_shard["model-00001-of-00001.safetensors"]
    assert late_key not in json_keys
    assert late_key not in index_meta["weight_map"]
Loading