Safetensor metadata mismatch fix in Mcore export#1422
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughShard ChangesSafetensors Write Order
Estimated code review effort🎯 2 (Simple) | ⏱️ ~8 minutes 🚥 Pre-merge checks | ✅ 6✅ Passed checks (6 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/export/plugins/mcore_custom.py (1)
308-320:⚠️ Potential issue | 🟠 Major | ⚡ Quick winFreeze
layer_state_dictonce to fully eliminate metadata/file drift.This reorder helps, but you still read a live mutable dict twice. If
layer_state_dictchanges aftersave_file(...)and before the metadata loop,.jsoncan diverge from the written.safetensors.Proposed hardening
for layer_index, layer_state_dict in layer_state_dicts.items(): filename = name_template.format(layer_index, total_layers) meta_filename = filename + ".json" ckpt_filename = filename + ".safetensors" + # Freeze key->tensor mapping used by both outputs. + frozen_layer_state_dict = dict(layer_state_dict) + # Write safetensors first, then build the per-layer meta JSON from the same dict. # Order matters: any late mutations to layer_state_dict (e.g. MTP tensors added after # the dict was first constructed) must be captured by both files. Writing safetensors # first ensures the JSON is always consistent with what is physically on disk. - save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"}) + save_file( + frozen_layer_state_dict, + save_directory + "/" + ckpt_filename, + metadata={"format": "pt"}, + ) weight_map = {} layer_total_size = 0 - for key, val in layer_state_dict.items(): + for key, val in frozen_layer_state_dict.items(): tensor_size = val.numel() * val.element_size() layer_total_size += tensor_size weight_map[key] = ckpt_filename🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/export/plugins/mcore_custom.py` around lines 308 - 320, layer_state_dict is mutated after being written which can cause metadata/file drift; snapshot it and use that immutable copy for both the safetensors write and the metadata loop. Specifically, create a frozen copy of layer_state_dict (e.g., snapshot = dict(layer_state_dict)) and pass snapshot to save_file(...) and iterate snapshot.items() when building weight_map/layer_total_size so save_file, weight_map, and layer_total_size are computed from the exact same data; reference save_file, layer_state_dict, weight_map, ckpt_filename in your change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@modelopt/torch/export/plugins/mcore_custom.py`:
- Around line 308-320: layer_state_dict is mutated after being written which can
cause metadata/file drift; snapshot it and use that immutable copy for both the
safetensors write and the metadata loop. Specifically, create a frozen copy of
layer_state_dict (e.g., snapshot = dict(layer_state_dict)) and pass snapshot to
save_file(...) and iterate snapshot.items() when building
weight_map/layer_total_size so save_file, weight_map, and layer_total_size are
computed from the exact same data; reference save_file, layer_state_dict,
weight_map, ckpt_filename in your change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 8c7dee09-3fac-42b5-92ed-6c95d8e50462
📒 Files selected for processing (1)
modelopt/torch/export/plugins/mcore_custom.py
a08d952 to
693a118
Compare
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1422 +/- ##
==========================================
- Coverage 76.27% 76.01% -0.26%
==========================================
Files 489 489
Lines 54415 54417 +2
==========================================
- Hits 41504 41365 -139
- Misses 12911 13052 +141
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
19b366a to
8951edb
Compare
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
Small reorder (write safetensors shard before its per-layer JSON) that mitigates a real symptom seen in MCore PTQ on Nemotron, but I'm not fully convinced it addresses the root cause and there are gaps:
- Root cause story is unclear from the code path I can see. In
unified_export_megatron.py::save_pretrained,_get_mtp_state_dict()is invoked synchronously beforesave_safetensors_by_layer_index, and the result is merged intolayer_state_dicts[num_layers]before the barrier. There's no obvious place in the current code where MTP state is added between the two writes inside the per-layer loop. If the real bug is that something is still mutatinglayer_state_dictbetweenjson.dumpandsave_file(or vice versa), then this reorder only flips which side wins — after this PR, a late mutation will cause the JSON to advertise keys that were never written to the.safetensorsshard, which is arguably worse than the original direction. A snapshot-then-write (e.g.weight_map/sizes computed once into local vars, then write safetensors and JSON from those, with no further dict reads) would actually be race-safe. Worth the author confirming what concretely mutateslayer_state_dictaftersave_filewas originally called. - Sibling function not updated.
save_safetensors(just above, used for the non-per-layer path) has the exact same JSON-then-safetensors ordering. If the ordering matters here, it presumably matters there too — please update both for consistency, or document why only the per-layer path is affected. - No test. The author checked "Did you write any new necessary tests? ❌". A regression test that exercises the MTP-included shard and asserts that every key in the per-layer index JSON is actually present in the corresponding
.safetensorswould lock this down. - No CHANGELOG entry (also unchecked).
|
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
Re-review focused on whether cjluo-nv's first two pieces of feedback have been addressed.
-
Feedback #1 — root cause / snapshot-then-write: NOT addressed. 💬 Author updated PR description and reordered the writes, but the underlying concern was that a bare reorder only flips which side "wins" a race — if
layer_state_dictis mutated betweensave_fileand the metadata loop, the JSON will now advertise keys that were never written to the.safetensorsshard (arguably worse than the original direction). The race-safe fix suggested by both cjluo-nv and CodeRabbit was to snapshot the dict once (e.g.frozen = dict(layer_state_dict)) and drive bothsave_fileand theweight_maploop from that frozen view. The current diff atmcore_custom.py:308-329still callssave_file(layer_state_dict, ...)and then iterates the livelayer_state_dict.items()— same insave_safetensorsat lines 254-268 withtensors. So if the author's hypothesis about late mutation is correct, the consistency hole is still open; if it's wrong, it'd be good to hear what concretely was observed to drift on Nemotron 3 Ultra. Worth a human reviewer pushing on this before merge. -
Feedback #2 — sibling
save_safetensorsnot updated: addressed. 💬 The new commit applies the same reorder tosave_safetensors(lines 254-256), so both paths are consistent. (The snapshot caveat from #1 applies to both, though.)
For completeness on the other items the author replied to: regression test is still missing (author says it's hard to design), and CHANGELOG was declined as "bugfix doesn't need it" — leaving those as-is for the human owner to weigh in on.
|
I agree Under PP>1, MTP is stage-local (typically last-stage only), and staggered shard emission makes this ordering sensitivity more visible. |
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
Re-review: sibling-function fix landed, but the core snapshot-then-write concern is still open and worth a human sign-off before merge.
- 💬 Author replied (PR comment): "snapshot-then-write is the generic race-safe pattern if in-loop mutation is expected... the observed MCore issue is better characterized as an ordering/coherence problem, not a proven same-process mutation between
save_fileand the metadata loop... explicit finalize-then-serialize is the stronger follow-up." — still warrants human sign-off because the current diff atmcore_custom.py:311-329(and the matchingsave_safetensorsat 254-268) still passes the livelayer_state_dict/tensorstosave_fileand then re-iterates the same live dict to buildweight_map. If the author's "no in-loop mutator" assumption holds, the reorder alone fixes the symptom; if it doesn't, the new ordering means the JSON can advertise keys that aren't in the.safetensorsshard — arguably worse than the original direction. A two-linefrozen = dict(layer_state_dict)snapshot would close this regardless of which hypothesis is right, and the author already agrees it's the more defensible pattern. - 💬 Sibling
save_safetensorsreorder: addressed in this revision (lines 254-256). Both paths are now consistent. - 💬 Author replied: "regression test is a little bit difficult to design. Need some brainstorm." — flagging anyway because the hypothesized failure mode (per-layer JSON listing keys not present in the shard) is straightforward to assert in a unit test by mocking
save_file/mutating the dict between writes; a real distributed Nemotron repro isn't required. - 💬 Author replied: "as bugfix, I don't think we need CHANGELOG here." — leaving for the human owner to weigh in on per project policy.
|
I would like to emphasize this is not a hypothetical issue. It was observed in production export artifacts. for key, val in layer_state_dict.items():
weight_map[key] = ckpt_filename
with open(meta_filename, "w") as f:
json.dump({"metadata": ..., "weight_map": weight_map}, f, indent=4)
save_file(layer_state_dict, ckpt_filename, metadata=...) # <-- runs LAST
The direction in your comment is inverted. By construction of the previous code before fix, the failure is "the JSON lists fewer keys than the shard contains." The shard is the superset, the index is a subset. The comment describes the opposite (JSON listing keys that aren't in the shard), which the previous code before fix can't produce. Regarding the regression test, I am supportive of adding a unit test along the lines you described, but to be clear about what it covers. This would be a freeze-invariant test, not a true regression for the production failure. The production failure requires PP>1 plus the MTP load path that mutates The actual end-to-end repro is the original MCore export that motivated this fix. I would like to call this out so we don't oversell what the test asserts. |
88b532f to
937eb21
Compare
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
Re-review.
Addressed since the previous round:
- 💬 Snapshot-then-write applied in both
save_safetensors(lines 254-256) andsave_safetensors_by_layer_index(lines 311-317).frozen_tensors = dict(tensors)/frozen_layer_state_dict = dict(layer_state_dict)is now the single source for bothsave_fileand theweight_maploop, so JSON and shard cannot diverge regardless of late mutations. This closes the main concern from cjluo-nv and CodeRabbit. - 💬 Sibling
save_safetensorsreorder + freeze: applied. - 💬 Author replied that a regression test is hard to design — a unit test was nonetheless added in
tests/unit/torch/export/test_mcore_save_safetensors.pythat mockssave_fileand mutates the source dict mid-write to assert the JSON does not pick up the late key. This is exactly the freeze-invariant test that was suggested. 👍
Still warrants human sign-off:
- The new test passes
tmp_path(apathlib.Path) assave_directory, but the production code doessave_directory + "/" + ckpt_filename, which isPath + strand raisesTypeError. The test as written looks like it would fail at the firstsave_filecall before reaching the assertions. Either I'm missing something or the test wasn't run — please double-check by either casting tostr(tmp_path)in the test or switching the source toos.path.join/Path /. (Same shape issue would also bitesave_safetensorsif it were ever exercised this way.) - 💬 Author replied: "as bugfix, I don't think we need CHANGELOG here." — deferred to human per project policy.
937eb21 to
352c422
Compare
c35051c to
ef71f1d
Compare
|
/ok to test ef71f1d |
Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
…hards for PP==2 in test_unified_export_megatron.py Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
ef71f1d to
799bce3
Compare
|
/ok to run 799bce3 |
|
/ok to test 799bce3 |
What does this PR do?
Type of change: Bug fix
In MCore export, shard metadata (
*.json) and shard weights (*.safetensors) are produced from mutable shard maps and can be generated from different views of the same dict. In real Nemotron 3 Ultra PTQ runs, I observed MTP-related drift where metadata and shard contents were not aligned. This is plausible because MTP is stage-local (typically last-stage only), so per-rank shard contents are intentionally asymmetric.The exact mutation interleaving is hard to prove from this code path alone, but the current implementation reads mutable shard maps across separate write steps, making metadata/weights consistency timing-sensitive. The issue is most visible with PP>1, where staggered per-shard writes widen the timing window between metadata and tensor-file generation.
This PR makes shard serialization deterministic in both paths:
.safetensorsfrom that snapshot,.jsonfrom the same snapshot.Apply this consistently to:
save_safetensors_by_layer_indexsave_safetensorsThis guarantees shard JSON and shard safetensors cannot diverge due to late dict mutations.
Usage
N/A
Testing
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.).CONTRIBUTING.md: N/AAdditional Information
Summary by CodeRabbit