Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add visualisation callback for image classification #228

Merged
merged 30 commits into from
Apr 22, 2021
Merged

Conversation

edgarriba
Copy link
Contributor

What does this PR do?

This PR implements the feature request from #74 and implements ImageClassificationDataVisualizer
to visualize the images and its associated labels right before feeding to the model.

Usage proposal:

   data_viz = ImageClassificationDataVisualizer.from_filepaths(
        train_filepaths=["path/img1.png", "path/img2.png"],
        train_labels=[0, 1],
        batch_size=2,
    )
    data_viz.show_train_batch()

Result obtained from the unittest:
image

TBD:

  • Exact information to be shown
  • Do we require at all matplotlib ?

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? [not needed for typos/docs]
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

)
data_viz.show_train_batch()
data_viz.show_val_batch()
data_viz.show_test_batch()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to decide how we validate this functionality during tests since it involves matplotlib visualization

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flash/vision/classification/data.py Outdated Show resolved Hide resolved
flash/vision/classification/data.py Outdated Show resolved Hide resolved
flash/vision/classification/data.py Outdated Show resolved Hide resolved
)
data_viz.show_train_batch()
data_viz.show_val_batch()
data_viz.show_test_batch()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@edgarriba edgarriba marked this pull request as draft April 19, 2021 15:52
@pep8speaks
Copy link

pep8speaks commented Apr 20, 2021

Hello @edgarriba! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-04-22 11:54:46 UTC

@codecov
Copy link

codecov bot commented Apr 20, 2021

Codecov Report

Merging #228 (f05c0d4) into master (7bfa80d) will increase coverage by 0.09%.
The diff coverage is 90.62%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #228      +/-   ##
==========================================
+ Coverage   86.81%   86.90%   +0.09%     
==========================================
  Files          58       58              
  Lines        2981     3055      +74     
==========================================
+ Hits         2588     2655      +67     
- Misses        393      400       +7     
Flag Coverage Δ
unittests 86.90% <90.62%> (+0.09%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
flash/vision/classification/data.py 87.98% <88.40%> (-0.23%) ⬇️
flash/data/base_viz.py 96.00% <88.88%> (-4.00%) ⬇️
flash/data/data_module.py 79.38% <100.00%> (+1.45%) ⬆️
flash/utils/imports.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 7bfa80d...f05c0d4. Read the comment docs.

@edgarriba edgarriba marked this pull request as ready for review April 20, 2021 23:29
@edgarriba edgarriba requested a review from tchaton April 21, 2021 07:47
Copy link
Contributor

@tchaton tchaton left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM ! Small changes needed.

@@ -92,6 +92,9 @@ def __init__(
# this may also trigger data preloading
self.set_running_stages()

# buffer to store the functions to visualise
self._fcn_white_list: Dict[str, Set[str]] = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This variable name isn't clean. Mind finding a name related to viz.

data_fetcher._show(stage)
if reset:
self.viz.batches[stage] = {}

def show_train_batch(self, reset: bool = True) -> None:
def show_train_batch(self, name: str = 'load_sample', reset: bool = True) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

List of hook names should also be supported.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can provide the following api:

def show_train_batch(self, name: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
    ...

"""This function is used to visualize a batch from the train dataloader."""
self._show_batch(_STAGES_PREFIX[RunningStage.TRAINING], reset=reset)
stage_name: str = _STAGES_PREFIX[RunningStage.TRAINING]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this logic to _show_batch to reduce duplicated code and raise a MisConfigurationError is the provided names aren't in _Preprocess_funcs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -147,8 +144,9 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization:

for stage in _STAGES_PREFIX.values():

for _ in range(10):
getattr(dm, f"show_{stage}_batch")(reset=False)
for fcn_name in _PREPROCESS_FUNCS:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep the previous one, it was asserting we could iterate more than the dataset length and the iterator was being reset.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And let's add yours too.

@@ -521,3 +558,68 @@ def from_filepaths(
seed=seed,
**kwargs
)


class _MatplotlibVisualization(BaseVisualization):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this class public.

def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
return _MatplotlibVisualization(*args, **kwargs)

def show(self) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's find another way to do that. This is confusing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can have a better naming here. An alternative but a bit more hacky, is to pass blocking flag across all the different functions until the matplotlib plt.show calls.

rows: int = num_samples // cols

# create figure and set title
fig, axs = plt.subplots(rows, cols)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Raise an exception if matplotlib isn't available here: _MATPLOTLIB_AVAILABLE

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, check if that's appropriate

# filter out the functions to visualise
func_name: str = self._fcn_white_list[running_stage]
func_names_list: Set[str] = list(set([func_name]) & set(_PREPROCESS_FUNCS))
if len(func_names_list) == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this check the show_{}_batches in the DataModule.

Copy link
Collaborator

@ethanwharris ethanwharris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! small comments 😃

flash/data/base_viz.py Outdated Show resolved Hide resolved
flash/data/base_viz.py Outdated Show resolved Hide resolved
flash/data/base_viz.py Outdated Show resolved Hide resolved
flash/data/data_module.py Outdated Show resolved Hide resolved
flash/vision/classification/data.py Outdated Show resolved Hide resolved
Copy link
Contributor

@tchaton tchaton left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great ! Nicely done !

rows: int = num_samples // cols

if not _MATPLOTLIB_AVAILABLE:
raise MisconfigurationException("You need matplotlib to visualise. Please, pip install matplotlib")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise MisconfigurationException("You need matplotlib to visualise. Please, pip install matplotlib")
raise MisconfigurationException("You need matplotlib to visualise. Please, use `pip install matplotlib`")

Copy link
Collaborator

@ethanwharris ethanwharris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! LGTM 😃

@tchaton tchaton merged commit 1f9e151 into master Apr 22, 2021
@tchaton tchaton deleted the feat/viz_callback branch April 22, 2021 12:08
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants