Fix #130

Merged · 7 commits · Jul 29, 2024
12 changes: 7 additions & 5 deletions lighter/callbacks/writer/file.py
@@ -20,15 +20,15 @@ class LighterFileWriter(LighterBaseWriter):
     for a more permanent solution, it can be added to the `self.writers` dictionary.
 
     Args:
-        directory (Union[str, Path]): Directory where the files should be written.
+        path (Union[str, Path]): Directory where the files should be written.
         writer (Union[str, Callable]): Name of the writer function registered in `self.writers` or a custom writer function.
             Available writers: "tensor", "image", "video", "itk_nrrd", "itk_seg_nrrd", "itk_nifti".
             A custom writer function must take two arguments: `path` and `tensor`, and write the tensor to the specified path.
             `tensor` is a single tensor without the batch dimension.
     """
 
-    def __init__(self, directory: Union[str, Path], writer: Union[str, Callable]) -> None:
-        super().__init__(directory, writer)
+    def __init__(self, path: Union[str, Path], writer: Union[str, Callable]) -> None:
+        super().__init__(path, writer)
 
     @property
     def writers(self) -> Dict[str, Callable]:
@@ -49,9 +49,11 @@ def write(self, tensor: torch.Tensor, id: Union[int, str]) -> None:
             tensor (Tensor): Tensor, without the batch dimension, to be written.
             id (Union[int, str]): Identifier, used for file-naming.
         """
+        if not self.path.is_dir():
+            raise RuntimeError(f"LighterFileWriter expects a directory path, got {self.path}")
+
         # Determine the path for the file based on prediction count. The suffix must be added by the writer function.
-        path = self.directory / str(id)
-        path.parent.mkdir(exist_ok=True, parents=True)
+        path = self.path / str(id)
         # Write the tensor to the file.
         self.writer(path, tensor)

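For context, a minimal usage sketch of the renamed API. It assumes the base class resolves `path` and `writer` in `__init__`; the paths, tensor shape, and the custom `write_npy` helper are illustrative, not part of this PR.

# Sketch: the mkdir() call was removed from write(), so the target directory
# must now exist before writing; write() raises RuntimeError otherwise.
from pathlib import Path

import torch

from lighter.callbacks import LighterFileWriter

Path("predictions").mkdir(parents=True, exist_ok=True)
writer = LighterFileWriter(path="predictions", writer="tensor")
writer.write(tensor=torch.rand(3, 32, 32), id=0)  # the writer function appends the file suffix

# A custom writer follows the (path, tensor) contract from the docstring:
def write_npy(path, tensor):
    import numpy as np
    np.save(f"{path}.npy", tensor.detach().cpu().numpy())

LighterFileWriter(path="predictions", writer=write_npy).write(torch.rand(8), id="sample")

The design change worth noting: `write` now validates the directory instead of silently creating parent directories, so creating it is the caller's responsibility.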
3 changes: 2 additions & 1 deletion lighter/utils/dynamic_imports.py
@@ -66,7 +66,8 @@ def import_module_from_path(module_name: str, module_path: str) -> None:
     # Based on https://stackoverflow.com/a/41595552.
 
     if module_name in sys.modules:
-        raise ValueError(f"{module_name} has already been imported as module.")
+        logger.warning(f"{module_name} has already been imported as module.")
+        return
 
     module_path = Path(module_path).resolve() / "__init__.py"
     if not module_path.is_file():
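With this change, a repeated dynamic import becomes a logged no-op rather than an error. A sketch of the new behavior; the module name and project path are illustrative:

from lighter.utils.dynamic_imports import import_module_from_path

import_module_from_path("project", "./projects/cifar10")
# Previously this second call raised ValueError; now it logs
# "project has already been imported as module." and returns.
import_module_from_path("project", "./projects/cifar10")

This makes the function idempotent, which presumably matters when several entry points within one session each trigger the project import.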
2 changes: 1 addition & 1 deletion projects/cifar10/experiments/monai_bundle_prototype.yaml
@@ -10,7 +10,7 @@ trainer:
   logger: False
   callbacks:
     - _target_: lighter.callbacks.LighterFileWriter
-      directory: '$f"{@project}/predictions"'
+      path: '$f"{@project}/predictions"'
       writer: tensor
 
 system:
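When the config is parsed, `@project` is interpolated into the project directory and the callback is instantiated from `_target_`. Roughly equivalent to the following sketch, with the project path illustrative:

from lighter.callbacks import LighterFileWriter

callback = LighterFileWriter(path="./projects/cifar10/predictions", writer="tensor")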
14 changes: 12 additions & 2 deletions tests/integration/test_cifar.py
@@ -14,13 +14,23 @@
             "fit",
             # Config file
             "./projects/cifar10/experiments/monai_bundle_prototype.yaml",
-        )
+        ),
+        (  # Method name
+            "test",
+            # Config file
+            "./projects/cifar10/experiments/monai_bundle_prototype.yaml",
+        ),
+        (  # Method name
+            "predict",
+            # Config file
+            "./projects/cifar10/experiments/monai_bundle_prototype.yaml",
+        ),
     ],
 )
 @pytest.mark.slow
 def test_trainer_method(method_name: str, config: str):
     """ """
-    kwargs = {"config": [config, test_overrides]}
 
+    kwargs = {"config": [config, test_overrides]}
     func_return = run(method_name, **kwargs)
     assert func_return is None
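Reassembled, the parametrization now exercises all three trainer methods against the same experiment config. A reconstruction of how the test reads after this change (the parametrize argument names are assumed from the test signature; `run` and `test_overrides` come from the test module's imports, which are outside this hunk):

import pytest

@pytest.mark.parametrize(
    "method_name, config",
    [
        ("fit", "./projects/cifar10/experiments/monai_bundle_prototype.yaml"),
        ("test", "./projects/cifar10/experiments/monai_bundle_prototype.yaml"),
        ("predict", "./projects/cifar10/experiments/monai_bundle_prototype.yaml"),
    ],
)
@pytest.mark.slow
def test_trainer_method(method_name: str, config: str):
    kwargs = {"config": [config, test_overrides]}
    func_return = run(method_name, **kwargs)
    assert func_return is None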
1 change: 0 additions & 1 deletion tests/integration/test_overrides.yaml
@@ -2,4 +2,3 @@ trainer#fast_dev_run: True
 trainer#accelerator: cpu
 system#batch_size: 16
 system#num_workers: 2
-trainer#callbacks: null
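Removing the `trainer#callbacks: null` override means the test runs no longer strip the callbacks defined in the experiment config, so the new `test` and `predict` runs exercise the `LighterFileWriter` callback end to end.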