Skip to content

Commit

Permalink
[Core] return_config=True now extracts just config, not full tarfile (NVIDIA#6346)
Browse files Browse the repository at this point in the history

Signed-off-by: smajumdar <[email protected]>
Signed-off-by: shane carroll <[email protected]>
  • Loading branch information
titu1994 authored and 1-800-BAD-CODE committed Apr 4, 2023
1 parent 93f9a93 commit b01fc88
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 5 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/nlp/parts/nlp_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def restore_from(
loaded_params = super().load_config_and_state_dict(
calling_cls, restore_path, override_config_path, map_location, strict, return_config, trainer,
)
if not isinstance(loaded_params, tuple):
if not isinstance(loaded_params, tuple) or return_config is True:
return loaded_params
conf, instance, state_dict = loaded_params
state_dict = self.modify_state_dict(conf, state_dict)
Expand Down
14 changes: 10 additions & 4 deletions nemo/core/connectors/save_restore_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def load_config_and_state_dict(

else:
# Extract the nemo file into the temporary directory
self._unpack_nemo_file(path2file=restore_path, out_folder=tmpdir)
self._unpack_nemo_file(
path2file=restore_path, out_folder=tmpdir, extract_config_only=return_config is True
)

# Change current working directory to
os.chdir(tmpdir)
Expand Down Expand Up @@ -239,7 +241,7 @@ def restore_from(
loaded_params = self.load_config_and_state_dict(
calling_cls, restore_path, override_config_path, map_location, strict, return_config, trainer,
)
if not isinstance(loaded_params, tuple):
if not isinstance(loaded_params, tuple) or return_config is True:
return loaded_params
conf, instance, state_dict = loaded_params
state_dict = self.modify_state_dict(conf, state_dict)
Expand Down Expand Up @@ -532,7 +534,7 @@ def _make_nemo_file_from_folder(filename, source_dir):
tar.add(source_dir, arcname=".")

@staticmethod
def _unpack_nemo_file(path2file: str, out_folder: str) -> str:
def _unpack_nemo_file(path2file: str, out_folder: str, extract_config_only: bool = False) -> str:
if not os.path.exists(path2file):
raise FileNotFoundError(f"{path2file} does not exist")

Expand All @@ -546,7 +548,11 @@ def _unpack_nemo_file(path2file: str, out_folder: str) -> str:
# can be older checkpoint => try compressed tar
tar_header = "r:gz"
tar = tarfile.open(path2file, tar_header)
tar.extractall(path=out_folder)
if not extract_config_only:
tar.extractall(path=out_folder)
else:
members = [x for x in tar.getmembers() if ".yaml" in x.name]
tar.extractall(path=out_folder, members=members)
tar.close()
return out_folder

Expand Down
63 changes: 63 additions & 0 deletions tests/core/test_save_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,69 @@ class MockModelV2(MockModel):
assert type(restored_model) == MockModelV2
assert type(restored_model._save_restore_connector) == MySaveRestoreConnector

@pytest.mark.unit
def test_restore_from_save_restore_connector_return_config(self):
    """Restoring with ``return_config=True`` must hand back the saved config as a ``DictConfig``."""

    class MySaveRestoreConnector(save_restore_connector.SaveRestoreConnector):
        # Redirect every save to an archive name carrying the "_XYZ" suffix.
        def save_to(self, model, save_path: str):
            super().save_to(model, save_path.replace(".nemo", "_XYZ.nemo"))

    class MockModelV2(MockModel):
        pass

    with tempfile.TemporaryDirectory() as workdir:
        requested_path = os.path.join(workdir, 'save_custom.nemo')

        # Build the model and save it through the custom connector.
        model_cfg = _mock_model_config().model
        original_model = MockModel(cfg=model_cfg, trainer=None)
        original_model._save_restore_connector = MySaveRestoreConnector()
        original_model.save_to(requested_path)

        # The connector rewrote the target name, so the suffixed archive is the one on disk.
        assert os.path.exists(os.path.join(workdir, 'save_custom_XYZ.nemo'))

        # Ask for the config only; a DictConfig (not a model instance) is expected back.
        config_only = MockModelV2.restore_from(
            requested_path.replace(".nemo", "_XYZ.nemo"),
            save_restore_connector=MySaveRestoreConnector(),
            return_config=True,
        )
        assert isinstance(config_only, DictConfig)
        assert original_model.cfg == config_only

@pytest.mark.unit
def test_restore_from_save_restore_connector_return_config_partial_tar_extraction(self):
    """``_unpack_nemo_file(..., extract_config_only=True)`` must unpack exactly one ``.yaml`` file."""

    class MySaveRestoreConnector(save_restore_connector.SaveRestoreConnector):
        # Redirect every save to an archive name carrying the "_XYZ" suffix.
        def save_to(self, model, save_path: str):
            save_path = save_path.replace(".nemo", "_XYZ.nemo")
            super().save_to(model, save_path)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Build the mock config
        cfg = _mock_model_config()

        # Create model and save it through the custom connector
        save_path = os.path.join(tmpdir, 'save_custom.nemo')
        model_with_custom_connector = MockModel(cfg=cfg.model, trainer=None)
        model_with_custom_connector._save_restore_connector = MySaveRestoreConnector()
        model_with_custom_connector.save_to(save_path)

        # The connector rewrote the target name, so the suffixed archive is the one on disk.
        true_save_path = os.path.join(tmpdir, 'save_custom_XYZ.nemo')
        assert os.path.exists(true_save_path)

        my_connector = MySaveRestoreConnector()

        with tempfile.TemporaryDirectory() as config_tmpdir:
            my_connector._unpack_nemo_file(true_save_path, out_folder=config_tmpdir, extract_config_only=True)
            # os.listdir already returns a list; no extra copy is needed.
            current_files = os.listdir(config_tmpdir)

            # only config file should have been extracted, no pytorch params
            assert len(current_files) == 1, f"unexpected extracted files: {current_files}"
            config_filepath = current_files[0]
            assert config_filepath.endswith(".yaml")

@pytest.mark.unit
def test_mock_model_model_collision(self):
# The usual pipeline is working just fine.
Expand Down

0 comments on commit b01fc88

Please sign in to comment.