Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] return_config=True now extracts just config, not full tarfile #6346

Merged
merged 1 commit into from
Apr 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nemo/collections/nlp/parts/nlp_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def restore_from(
loaded_params = super().load_config_and_state_dict(
calling_cls, restore_path, override_config_path, map_location, strict, return_config, trainer,
)
if not isinstance(loaded_params, tuple):
if not isinstance(loaded_params, tuple) or return_config is True:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if not isinstance(loaded_params, tuple) or return_config is True:
if not isinstance(loaded_params, tuple) or return_config:

Could we not just simplify to this? Seems it's always a bool

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer explicit checks generally, even for bool assertion

return loaded_params
conf, instance, state_dict = loaded_params
state_dict = self.modify_state_dict(conf, state_dict)
Expand Down
14 changes: 10 additions & 4 deletions nemo/core/connectors/save_restore_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def load_config_and_state_dict(

else:
# Extract the nemo file into the temporary directory
self._unpack_nemo_file(path2file=restore_path, out_folder=tmpdir)
self._unpack_nemo_file(
path2file=restore_path, out_folder=tmpdir, extract_config_only=return_config is True
)

# Change current working directory to the temporary directory
os.chdir(tmpdir)
Expand Down Expand Up @@ -239,7 +241,7 @@ def restore_from(
loaded_params = self.load_config_and_state_dict(
calling_cls, restore_path, override_config_path, map_location, strict, return_config, trainer,
)
if not isinstance(loaded_params, tuple):
if not isinstance(loaded_params, tuple) or return_config is True:
return loaded_params
conf, instance, state_dict = loaded_params
state_dict = self.modify_state_dict(conf, state_dict)
Expand Down Expand Up @@ -532,7 +534,7 @@ def _make_nemo_file_from_folder(filename, source_dir):
tar.add(source_dir, arcname=".")

@staticmethod
def _unpack_nemo_file(path2file: str, out_folder: str) -> str:
def _unpack_nemo_file(path2file: str, out_folder: str, extract_config_only: bool = False) -> str:
if not os.path.exists(path2file):
raise FileNotFoundError(f"{path2file} does not exist")

Expand All @@ -546,7 +548,11 @@ def _unpack_nemo_file(path2file: str, out_folder: str) -> str:
# can be older checkpoint => try compressed tar
tar_header = "r:gz"
tar = tarfile.open(path2file, tar_header)
tar.extractall(path=out_folder)
if not extract_config_only:
tar.extractall(path=out_folder)
else:
members = [x for x in tar.getmembers() if ".yaml" in x.name]
tar.extractall(path=out_folder, members=members)
tar.close()
return out_folder

Expand Down
63 changes: 63 additions & 0 deletions tests/core/test_save_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,69 @@ class MockModelV2(MockModel):
assert type(restored_model) == MockModelV2
assert type(restored_model._save_restore_connector) == MySaveRestoreConnector

@pytest.mark.unit
def test_restore_from_save_restore_connector_return_config(self):
    """Restoring with ``return_config=True`` yields the saved model config.

    A custom connector redirects the save path to ``*_XYZ.nemo``; the test
    then restores from that redirected file and checks that a ``DictConfig``
    equal to the original model's ``cfg`` comes back.
    """

    class MySaveRestoreConnector(save_restore_connector.SaveRestoreConnector):
        # Redirect every save to a "_XYZ" suffixed archive name.
        def save_to(self, model, save_path: str):
            super().save_to(model, save_path.replace(".nemo", "_XYZ.nemo"))

    class MockModelV2(MockModel):
        pass

    with tempfile.TemporaryDirectory() as workdir:
        # Build the model from the mock config and attach the custom connector.
        mock_cfg = _mock_model_config()
        requested_path = os.path.join(workdir, 'save_custom.nemo')
        original_model = MockModel(cfg=mock_cfg.model, trainer=None)
        original_model._save_restore_connector = MySaveRestoreConnector()
        original_model.save_to(requested_path)

        # The connector must have written to the redirected location.
        expected_archive = os.path.join(workdir, 'save_custom_XYZ.nemo')
        assert os.path.exists(expected_archive)

        # Restore only the config through the subclass.
        redirected_path = requested_path.replace(".nemo", "_XYZ.nemo")
        restored_cfg = MockModelV2.restore_from(
            redirected_path, save_restore_connector=MySaveRestoreConnector(), return_config=True,
        )

        assert isinstance(restored_cfg, DictConfig)
        assert original_model.cfg == restored_cfg

@pytest.mark.unit
def test_restore_from_save_restore_connector_return_config_partial_tar_extraction(self):
    """``_unpack_nemo_file(..., extract_config_only=True)`` extracts only the config.

    Saves a mock model through a connector that redirects the archive to
    ``*_XYZ.nemo``, unpacks that archive in config-only mode into a fresh
    temporary directory, and checks that exactly one ``.yaml`` file was
    written there.
    """

    class MySaveRestoreConnector(save_restore_connector.SaveRestoreConnector):
        # Redirect every save to a "_XYZ" suffixed archive name.
        def save_to(self, model, save_path: str):
            save_path = save_path.replace(".nemo", "_XYZ.nemo")
            super().save_to(model, save_path)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Update config
        cfg = _mock_model_config()

        # Create model
        save_path = os.path.join(tmpdir, 'save_custom.nemo')
        model_with_custom_connector = MockModel(cfg=cfg.model, trainer=None)
        model_with_custom_connector._save_restore_connector = MySaveRestoreConnector()
        model_with_custom_connector.save_to(save_path)

        true_save_path = os.path.join(tmpdir, 'save_custom_XYZ.nemo')
        assert os.path.exists(true_save_path)

        my_connector = MySaveRestoreConnector()

        with tempfile.TemporaryDirectory() as config_tmpdir:
            my_connector._unpack_nemo_file(true_save_path, out_folder=config_tmpdir, extract_config_only=True)
            # Capture the listing before the directory is cleaned up.
            current_files = os.listdir(config_tmpdir)

        assert len(current_files) == 1  # only config file should have been extracted, no pytorch params
        config_filepath = current_files[0]
        assert config_filepath.endswith(".yaml")

@pytest.mark.unit
def test_mock_model_model_collision(self):
# The usual pipeline is working just fine.
Expand Down