Skip to content

Commit

Permalink
Initialize ArtifactItem with dummy values to keep changes minimal
Browse files Browse the repository at this point in the history
Signed-off-by: Jan Lasek <[email protected]>
  • Loading branch information
janekl committed Dec 18, 2023
1 parent 11aad57 commit 1a389bf
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
16 changes: 9 additions & 7 deletions nemo/core/connectors/save_restore_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ def register_artifact(self, model, config_path: str, src: str, verify_src_exists
"""
app_state = AppState()

artifact_item = model_utils.ArtifactItem()

# This is for backward compatibility, if the src objects exists simply inside of the tarfile
# without its key having been overriden, this pathway will be used.
src_obj_name = os.path.basename(src)
Expand All @@ -370,18 +372,18 @@ def register_artifact(self, model, config_path: str, src: str, verify_src_exists
# src is a local existing path - register artifact and return exact same path for usage by the model
if os.path.exists(os.path.abspath(src)):
return_path = os.path.abspath(src)
path_type = model_utils.ArtifactPathType.LOCAL_PATH
artifact_item.path_type = model_utils.ArtifactPathType.LOCAL_PATH

# this is the case when artifact must be retried from the nemo file
# we are assuming that the location of the right nemo file is available from _MODEL_RESTORE_PATH
elif src.startswith("nemo:"):
return_path = os.path.abspath(os.path.join(app_state.nemo_file_folder, src[5:]))
path_type = model_utils.ArtifactPathType.TAR_PATH
artifact_item.path_type = model_utils.ArtifactPathType.TAR_PATH

# backward compatibility implementation
elif os.path.exists(src_obj_path):
return_path = src_obj_path
path_type = model_utils.ArtifactPathType.TAR_PATH
artifact_item.path_type = model_utils.ArtifactPathType.TAR_PATH
else:
if verify_src_exists:
raise FileNotFoundError(
Expand All @@ -396,7 +398,7 @@ def register_artifact(self, model, config_path: str, src: str, verify_src_exists

assert os.path.exists(return_path)

artifact_item = model_utils.ArtifactItem(path=os.path.abspath(src), path_type=path_type,)
artifact_item.path = os.path.abspath(src)
model.artifacts[config_path] = artifact_item
# we were called by ModelPT
if hasattr(model, "cfg"):
Expand Down Expand Up @@ -487,9 +489,9 @@ def _handle_artifacts(self, model, nemo_file_folder):
shutil.copy2(artifact_base_name, os.path.join(nemo_file_folder, artifact_uniq_name))

# Update artifacts registry
new_artiitem = model_utils.ArtifactItem(
path="nemo:" + artifact_uniq_name, path_type=model_utils.ArtifactPathType.TAR_PATH,
)
new_artiitem = model_utils.ArtifactItem()
new_artiitem.path = "nemo:" + artifact_uniq_name
new_artiitem.path_type = model_utils.ArtifactPathType.TAR_PATH
model.artifacts[conf_path] = new_artiitem
finally:
# change back working directory
Expand Down
4 changes: 2 additions & 2 deletions nemo/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ class ArtifactPathType(Enum):

@dataclass
class ArtifactItem:
path: str
path_type: ArtifactPathType
path: str = ""
path_type: ArtifactPathType = ArtifactPathType.LOCAL_PATH
hashed_path: Optional[str] = None


Expand Down

0 comments on commit 1a389bf

Please sign in to comment.