
Feature: add CAREamist predict to disk method #189

Closed
wants to merge 23 commits into from

Conversation

melisande-c
Member

Description

  • What: Add a predict_to_disk method to the CAREamist class. It uses the PredictionWriterCallback; an instance of which has been added as an attribute of CAREamist.
  • Why: Allows users of the CAREamist class to write predictions to disk.
  • How:
    • Add an instance of PredictionWriterCallback as an attribute of CAREamist and append it to the list of callbacks.
    • The bulk of CAREamist.predict has been moved to CAREamist._predict_implementation so that it can be reused by both predict and predict_to_disk, with the parameter return_predictions set to True or False respectively.
    • PredictionWriterCallback has a context manager for toggling writing predictions on and off (added for aesthetic reasons), but it might also be useful for the lightning API users.
    • The file names for the predictions are taken from the dataset classes so they match the input files.
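The on/off toggling mentioned above can be sketched with a context manager. This is a minimal illustration, not the actual CAREamics API; the class body, attribute names and the `writing_disabled` helper are all assumptions:

```python
from contextlib import contextmanager


class PredictionWriterCallback:
    """Sketch of a prediction-writer callback whose writing can be toggled."""

    def __init__(self, write_strategy):
        self.write_strategy = write_strategy
        self.writing_predictions = True  # writing is on by default

    @contextmanager
    def writing_disabled(self):
        # hypothetical helper: switch writing off, restore the previous
        # state on exit even if an exception is raised
        previous = self.writing_predictions
        self.writing_predictions = False
        try:
            yield
        finally:
            self.writing_predictions = previous

    def write_on_batch_end(self, prediction):
        if not self.writing_predictions:
            return
        self.write_strategy.write_batch(prediction)
```

With this shape, `CAREamist.predict` can wrap its call to the trainer in `with self.prediction_writer.writing_disabled(): ...`, while `predict_to_disk` leaves writing on.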

Getting the file names

This is what has caused the most trouble. I have come up with a solution that works, but I dislike it because it introduces a lot of coupling between the WriteStrategy classes and the dataset classes. Additionally, I can envision situations where it breaks down.

Currently, I have modified IterablePredDataset and IterableTiledPredDataset so that they store a list mapping each sample to its source file. Then, in the WriteStrategy, if the index of a sample is known, its file name can be retrieved from the dataset. The concrete implementations are as follows:

  • In the WriteImage strategy the index of the sample can be calculated from the batch index and the batch size.
  • In the CacheTiles strategy the sample index is tracked with an internal attribute, sample_count. This is fragile because there is no reset mechanism: if the same CacheTiles instance is used again it will fail.
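The untiled (WriteImage) lookup described in the first bullet can be sketched as follows. DummyDataset stands in for IterablePredDataset, and all attribute and function names here are illustrative assumptions, not the real CAREamics interfaces:

```python
class DummyDataset:
    """Stands in for IterablePredDataset: one entry per yielded sample."""

    def __init__(self, data_files, sample_file_indices):
        self.data_files = data_files
        # sample_file_indices[i] is the index of the file that sample i came from
        self.sample_file_indices = sample_file_indices


def sample_file_for_batch_item(dataset, batch_index, batch_size, item_index):
    # WriteImage case: with untiled data, the global sample index follows
    # directly from the batch index and batch size
    sample_index = batch_index * batch_size + item_index
    file_index = dataset.sample_file_indices[sample_index]
    return dataset.data_files[file_index]
```

For example, with files `["a.tiff", "b.tiff"]` where the first file contributes two samples, `sample_file_indices` would be `[0, 0, 1]`, and the third sample overall resolves to `"b.tiff"`.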

Notes:

  • In the datasets I tried having just a current_file_index attribute, but the way lightning pre-fetches batches meant it didn't work.

Alternatives:

  • The datasets return the filename that the sample/tile belongs to. This will take some extra work because there are places where a batch length of 2 is used to infer whether a prediction is tiled, which would obviously no longer hold.

Changes Made

  • Added:
    • CAREamist.predict_to_disk method.
    • CAREamist._predict_implementation method.
    • CAREamist.prediction_writer attribute.
    • validate_unet_tile_size function (in prediction_utils package).
    • IterablePredDataset.sample_file_indices & IterableTiledPredDataset.sample_file_indices attributes.
    • Tests
  • Modified:
    • PredictionWriterCallback class.
    • CacheTiles class.
    • WriteImage class.
    • get_sample_file_path function (in lightning.callbacks.prediction_writer_callback.file_utils module).
    • IterablePredDataset.__iter__ & IterableTiledPredDataset.__iter__ methods.

Breaking changes

Shouldn't alter the behaviour of the rest of the code base.

Additional Notes and Examples

Currently, if a file contains more than one sample, for example two samples with shape (2, 32, 32), axes "SYX" and file name "image.tiff", then a separate file is saved for each sample prediction, named "image_0.tiff" and "image_1.tiff". It is implemented this way because it was the path of least resistance, but if you think this will be too inconvenient or confusing for users, I can change it with a little more work by altering the WriteStrategy classes.
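The per-sample naming scheme described above can be sketched as a small helper (the function name is hypothetical, the naming pattern follows the "image_0.tiff" / "image_1.tiff" example):

```python
from pathlib import Path


def per_sample_file_names(file_path, n_samples):
    """Sketch: a file "image.tiff" with S samples yields "image_0.tiff", ..."""
    path = Path(file_path)
    return [
        path.with_name(f"{path.stem}_{i}{path.suffix}") for i in range(n_samples)
    ]
```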

Also, I realised I am missing a couple of tests; I will add them shortly.

Need opinions on:

  • A better way to get/generate the file names for the predictions. Should they be returned with the batches by the dataset classes? Or should we leave it as it is for now, even though it introduces a lot of coupling?
  • Should I update the implementation so the saved file structure mirrors the input data (instead of one file per sample)?

Please ensure your PR meets the following requirements:

  • Code builds and passes tests locally, including doctests
  • New tests have been added (for bug fixes/features)
  • Pre-commit passes
  • PR to the documentation exists (for bug fixes / features)

can now validate on set property
@veegalinova
Collaborator

I think the suggestion to return the filename from the dataset is great.
We could use namedtuple or dict as they both work with collate_fn.

As an example, the return structure could look like this:

from pathlib import Path
from typing import NamedTuple

import torch

class DataItem(NamedTuple):
    data_patch: torch.Tensor
    tile_info: TileInformation  # defined in careamics
    file_path: Path

But I think it is out of scope for this PR and something to discuss later.
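The reason a namedtuple works here is that PyTorch's default collate function recurses into namedtuples and dicts, collating each field and preserving the container type, so batches keep named access. A dependency-free sketch (using `object` in place of `torch.Tensor` and the real `TileInformation` class, both assumptions dropped for self-containment) also shows it stays compatible with the existing `x, *aux` unpacking:

```python
from pathlib import Path
from typing import NamedTuple


class DataItem(NamedTuple):
    # data_patch would be a torch.Tensor and tile_info a TileInformation
    # in practice; `object` keeps this sketch free of dependencies
    data_patch: object
    tile_info: object
    file_path: Path


item = DataItem(data_patch=[[0.0]], tile_info=None, file_path=Path("image.tiff"))

# namedtuples still support the `x, *aux` unpacking used elsewhere in the code,
# while also allowing explicit access via item.file_path
x, *aux = item
```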

@veegalinova
Copy link
Collaborator

Should I update the implementation so the saved file structure mirrors the input data (instead of 1 file per sample)

I think the saved file should indeed be the same shape as the input. Otherwise, the user would need to combine them back together manually.
Is there a way to "cache" the samples until the S dimension is exhausted and then reshape it back to the input dimensions before writing? Or would it require too many changes in the current code?

@melisande-c
Member Author

I think the suggestion to return the filename from the dataset is great.
We could use namedtuple or dict as they both work with collate_fn.

I like this; I have been thinking the returns from the datasets should probably be more explicit. There are places where the output is unpacked as x, *aux, but namedtuples can be unpacked, so this is OK. What will cause problems, however, is that I am pretty sure there are some places where an output's length is used to infer whether it is tiled or not.

I think I can work on a different PR to modify the datasets to return namedtuples. Once that is done, I can integrate it into this PR.

@melisande-c
Member Author

Should I update the implementation so the saved file structure mirrors the input data (instead of 1 file per sample)

I think the saved file should indeed be the same shape as the input. Otherwise, the user would need to combine them back together manually. Is there a way to "cache" the samples until the S dimension is exhausted and then reshape it back to the input dimensions before writing? Or would it require too many changes in the current code?

I think this is the route I would take. The problem is where the WriteStrategy class gets the S dimension from (similar to the problem of where it gets the filename from).

src/careamics/careamist.py
"""
# create and set correct write strategy
tiled = tile_size is not None
write_strategy = create_write_strategy(
Collaborator


Could it be a bit more concise to pack the write strategy selection into the prediction_writer itself, something like self.prediction_writer.set_writing_strategy(write_type, ...) and create the write_strategy there directly?

Member Author

@melisande-c Jul 24, 2024


This was a deliberate design choice to reduce coupling between PredictionWriterCallback and the WriteStrategy classes. Removing all responsibility for WriteStrategy construction from PredictionWriterCallback means that the write strategies can be altered almost arbitrarily, with the only constraint that they have a write_batch method, without having to alter PredictionWriterCallback. This is the key aim of the strategy pattern.

Member Author

@melisande-c Jul 24, 2024


Ok, I did break the pattern rules with the from_write_func_params method though 😅. I put this in for the convenience of users of the lightning API, but I am now considering removing it. The issue is that create_write_strategy might have to change a bit, because I don't know the parameters that the WriteZarrTiles strategy will need for initialisation. This means the arguments of from_write_func_params will also need to change, and so will those of a potential set_write_strategy method.

@melisande-c
Member Author

Should I update the implementation so the saved file structure mirrors the input data (instead of 1 file per sample)

I think the saved file should indeed be the same shape as the input. Otherwise, the user would need to combine them back together manually. Is there a way to "cache" the samples until the S dimension is exhausted and then reshape it back to the input dimensions before writing? Or would it require too many changes in the current code?

I think this is the route I would take. The problem is where the WriteStrategy class gets the S dimension from (similar to the problem of where it gets the filename from).

Actually, I've thought about it a bit more, and the filename could be used to solve this problem: samples will be "cached" until the filename changes, then concatenated and saved.
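The filename-change caching idea could be sketched as a small strategy. This is an illustration of the proposal only, not existing CAREamics code; the class and method names are hypothetical, and a real implementation would stack arrays along S rather than collect them in a plain list:

```python
class CacheSamplesSketch:
    """Hypothetical strategy: cache per-sample predictions until the source
    file name changes, then write one output per input file."""

    def __init__(self, write_func):
        # write_func(file_name, samples) performs the actual save
        self.write_func = write_func
        self.cache = []
        self.current_file = None

    def add_sample(self, file_name, sample):
        # a change in source file name signals the S dimension is exhausted
        if self.current_file is not None and file_name != self.current_file:
            self.flush()
        self.current_file = file_name
        self.cache.append(sample)

    def flush(self):
        # in practice the cached samples would be stacked along the S axis
        if self.cache:
            self.write_func(self.current_file, list(self.cache))
        self.cache = []
        self.current_file = None
```

A final `flush()` after the last batch handles the last file; note this relies on each file's samples arriving contiguously, which holds for the iterable prediction datasets.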

@melisande-c
Member Author

I will close this while I make some changes.

I need to add "sample caching" so that the prediction file structure matches that of the input file structure.

I also need to find a solution for how the prediction writer will access the file names. A simpler solution (that I was originally trying to avoid) is to simply pass the filenames to the write strategy classes. I was trying to avoid this because it will be annoying for users not using the CAREamist class; I will try to find a way to make it feel convenient.

Also, as a small note, I will reiterate: the main motivation for this whole setup is to be able to save tiles to Zarr when an entire file might not fit into memory. This feature probably won't be implemented for a while, but I need to make sure it won't be difficult.
