Skip to content

Commit

Permalink
Relative paths for BMZ
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Dec 19, 2023
1 parent 3af12f2 commit 88ba7b8
Showing 1 changed file with 7 additions and 14 deletions.
21 changes: 7 additions & 14 deletions src/careamics/bioimage/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,19 @@ def save_bioimage_model(
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# save chekpoint entries in separate files
weight_path = workdir.joinpath("model_weights.pth")
weight_path = Path("model_weights.pth")
torch.save(checkpoint["model_state_dict"], weight_path)

optim_path = workdir.joinpath("optim.pth")
optim_path = Path("optim.pth")
torch.save(checkpoint["optimizer_state_dict"], optim_path)

scheduler_path = workdir.joinpath("scheduler.pth")
scheduler_path = Path("scheduler.pth")
torch.save(checkpoint["scheduler_state_dict"], scheduler_path)

grad_path = workdir.joinpath("grad.pth")
grad_path = Path("grad.pth")
torch.save(checkpoint["grad_scaler_state_dict"], grad_path)

config_path = workdir.joinpath("config.pth")
config_path = Path("config.pth")
torch.save(config.model_dump(), config_path)

# create attachments
Expand All @@ -68,7 +68,7 @@ def save_bioimage_model(
]

# create requirements file
requirements = workdir.joinpath("requirements.txt")
requirements = Path("requirements.txt")
with open(requirements, "w") as f:
f.write("git+https://github.com/CAREamics/careamics.git")

Expand Down Expand Up @@ -100,14 +100,7 @@ def save_bioimage_model(
**specs,
)

# remove the temporary files
weight_path.unlink()
optim_path.unlink()
scheduler_path.unlink()
grad_path.unlink()
config_path.unlink()

# BMZ creates spurious files (copied before zipping)
# remove temporary files
for file in temp_folder.glob("*"):
file.unlink()

Expand Down

0 comments on commit 88ba7b8

Please sign in to comment.