diff --git a/src/megatron/bridge/models/hf_pretrained/state.py b/src/megatron/bridge/models/hf_pretrained/state.py index 6e4770ca10..b98e667628 100644 --- a/src/megatron/bridge/models/hf_pretrained/state.py +++ b/src/megatron/bridge/models/hf_pretrained/state.py @@ -795,9 +795,9 @@ def save_generator( ) # Final check on whether all original tensors were written. - unsaved_keys = all_expected_keys - all_saved_keys + unsaved_keys = all_expected_keys.intersection(all_saved_keys) if not unsaved_keys: - extra_keys = all_yielded_keys - all_expected_keys + extra_keys = all_yielded_keys.intersection(all_expected_keys) if extra_keys: print( f"\nSuccess: All tensors from the original checkpoint were written. "