From c88d891d7b8cb491f56cc622fa18c1aebcb7ea6e Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Mon, 29 Jul 2024 20:06:08 +0100 Subject: [PATCH] Fix LighterFileWriter (#130) * Bump version * Fix LighterFileWriter + add more test cases --------- Co-authored-by: GitHub Action --- lighter/callbacks/writer/file.py | 12 +++++++----- lighter/utils/dynamic_imports.py | 3 ++- .../experiments/monai_bundle_prototype.yaml | 2 +- tests/integration/test_cifar.py | 14 ++++++++++++-- tests/integration/test_overrides.yaml | 1 - 5 files changed, 22 insertions(+), 10 deletions(-) diff --git a/lighter/callbacks/writer/file.py b/lighter/callbacks/writer/file.py index 7e92c9d8..583df38a 100644 --- a/lighter/callbacks/writer/file.py +++ b/lighter/callbacks/writer/file.py @@ -20,15 +20,15 @@ class LighterFileWriter(LighterBaseWriter): for a more permanent solution, it can be added to the `self.writers` dictionary. Args: - directory (Union[str, Path]): Directory where the files should be written. + path (Union[str, Path]): Directory where the files should be written. writer (Union[str, Callable]): Name of the writer function registered in `self.writers` or a custom writer function. Available writers: "tensor", "image", "video", "itk_nrrd", "itk_seg_nrrd", "itk_nifti". A custom writer function must take two arguments: `path` and `tensor`, and write the tensor to the specified path. `tensor` is a single tensor without the batch dimension. """ - def __init__(self, directory: Union[str, Path], writer: Union[str, Callable]) -> None: - super().__init__(directory, writer) + def __init__(self, path: Union[str, Path], writer: Union[str, Callable]) -> None: + super().__init__(path, writer) @property def writers(self) -> Dict[str, Callable]: @@ -49,9 +49,11 @@ def write(self, tensor: torch.Tensor, id: Union[int, str]) -> None: tensor (Tensor): Tensor, without the batch dimension, to be written. id (Union[int, str]): Identifier, used for file-naming. """ + if not self.path.is_dir(): + raise RuntimeError(f"LighterFileWriter expects a directory path, got {self.path}") + # Determine the path for the file based on prediction count. The suffix must be added by the writer function. - path = self.directory / str(id) - path.parent.mkdir(exist_ok=True, parents=True) + path = self.path / str(id) # Write the tensor to the file. self.writer(path, tensor) diff --git a/lighter/utils/dynamic_imports.py b/lighter/utils/dynamic_imports.py index b54911cf..47ce3903 100644 --- a/lighter/utils/dynamic_imports.py +++ b/lighter/utils/dynamic_imports.py @@ -66,7 +66,8 @@ def import_module_from_path(module_name: str, module_path: str) -> None: # Based on https://stackoverflow.com/a/41595552. if module_name in sys.modules: - raise ValueError(f"{module_name} has already been imported as module.") + logger.warning(f"{module_name} has already been imported as module.") + return module_path = Path(module_path).resolve() / "__init__.py" if not module_path.is_file(): diff --git a/projects/cifar10/experiments/monai_bundle_prototype.yaml b/projects/cifar10/experiments/monai_bundle_prototype.yaml index 60350828..ad2dffcd 100644 --- a/projects/cifar10/experiments/monai_bundle_prototype.yaml +++ b/projects/cifar10/experiments/monai_bundle_prototype.yaml @@ -10,7 +10,7 @@ trainer: logger: False callbacks: - _target_: lighter.callbacks.LighterFileWriter - directory: '$f"{@project}/predictions"' + path: '$f"{@project}/predictions"' writer: tensor system: diff --git a/tests/integration/test_cifar.py b/tests/integration/test_cifar.py index f7f4c602..3cecdde1 100644 --- a/tests/integration/test_cifar.py +++ b/tests/integration/test_cifar.py @@ -14,13 +14,23 @@ "fit", # Config fiile "./projects/cifar10/experiments/monai_bundle_prototype.yaml", - ) + ), + ( # Method name + "test", + # Config fiile + "./projects/cifar10/experiments/monai_bundle_prototype.yaml", + ), + ( # Method name + "predict", + # Config fiile + "./projects/cifar10/experiments/monai_bundle_prototype.yaml", + ), ], ) @pytest.mark.slow def test_trainer_method(method_name: str, config: str): """ """ - kwargs = {"config": [config, test_overrides]} + kwargs = {"config": [config, test_overrides]} func_return = run(method_name, **kwargs) assert func_return is None diff --git a/tests/integration/test_overrides.yaml b/tests/integration/test_overrides.yaml index 46ad9c22..d44c0e78 100644 --- a/tests/integration/test_overrides.yaml +++ b/tests/integration/test_overrides.yaml @@ -2,4 +2,3 @@ trainer#fast_dev_run: True trainer#accelerator: cpu system#batch_size: 16 system#num_workers: 2 -trainer#callbacks: null