Add visualisation callback for image classification #228
)
data_viz.show_train_batch()
data_viz.show_val_batch()
data_viz.show_test_batch()
We need to decide how to validate this functionality in tests, since it involves matplotlib visualization.
You could use: https://matplotlib.org/stable/devel/testing.html
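One lightweight option, sketched here as an alternative to the image-comparison tooling described at that link, is to select the non-interactive Agg backend and assert on the structure of the figure instead of its pixels. The function `show_batch` is a hypothetical stand-in for the `show_*_batch` methods in this PR, not their actual implementation.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend: no window is opened
import matplotlib.pyplot as plt


def show_batch(images, cols: int = 2):
    """Hypothetical stand-in for a show_*_batch method: one subplot per image."""
    rows = len(images) // cols
    # squeeze=False keeps axs two-dimensional even for a single row or column
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    for ax, image in zip(axs.flat, images):
        ax.imshow(image)
        ax.axis("off")
    return fig


def test_show_batch_draws_one_axis_per_image():
    plt.close("all")  # start from a clean state
    images = [[[0.0, 1.0], [1.0, 0.0]] for _ in range(4)]
    fig = show_batch(images)
    # exactly one figure is open, and it holds one axis per image
    assert plt.get_fignums() == [fig.number]
    assert len(fig.axes) == 4
    plt.close(fig)
```

A test of this kind checks that the plotting path runs and builds the expected layout without requiring a display, but it does not verify what the rendered image looks like.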
Hello @edgarriba! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2021-04-22 11:54:46 UTC
176634c to dfdf333
dfdf333 to 0ef5d2a
Codecov Report
@@            Coverage Diff             @@
##           master     #228      +/-   ##
==========================================
+ Coverage   86.81%   86.90%   +0.09%
==========================================
  Files          58       58
  Lines        2981     3055      +74
==========================================
+ Hits         2588     2655      +67
- Misses        393      400       +7
Overall LGTM ! Small changes needed.
flash/data/data_module.py
Outdated
@@ -92,6 +92,9 @@ def __init__(
# this may also trigger data preloading
self.set_running_stages()

# buffer to store the functions to visualise
self._fcn_white_list: Dict[str, Set[str]] = {}
This variable name isn't clear. Mind finding a name related to visualization?
flash/data/data_module.py
Outdated
data_fetcher._show(stage)
if reset:
    self.viz.batches[stage] = {}

- def show_train_batch(self, reset: bool = True) -> None:
+ def show_train_batch(self, name: str = 'load_sample', reset: bool = True) -> None:
List of hook names should also be supported.
I can provide the following API:
def show_train_batch(self, name: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
...
"""This function is used to visualize a batch from the train dataloader."""
- self._show_batch(_STAGES_PREFIX[RunningStage.TRAINING], reset=reset)
+ stage_name: str = _STAGES_PREFIX[RunningStage.TRAINING]
Move this logic to _show_batch to reduce duplicated code, and raise a MisconfigurationException if the provided names aren't in _PREPROCESS_FUNCS.
done
tests/data/test_callbacks.py
Outdated
@@ -147,8 +144,9 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization:

for stage in _STAGES_PREFIX.values():

    for _ in range(10):
        getattr(dm, f"show_{stage}_batch")(reset=False)
    for fcn_name in _PREPROCESS_FUNCS:
Keep the previous one, it was asserting we could iterate more than the dataset length and the iterator was being reset.
And let's add yours too.
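The behavior the earlier test asserted can be sketched as follows: asking for more batches than the dataset holds should restart the iterator rather than fail. `BatchCycler` is a hypothetical helper written for this illustration, not the Flash implementation.

```python
from typing import Iterable, Iterator, Optional


class BatchCycler:
    """Hypothetical helper that restarts its iterator when the data runs out."""

    def __init__(self, batches: Iterable) -> None:
        self._batches = batches
        self._iterator: Optional[Iterator] = None

    def next_batch(self):
        if self._iterator is None:
            self._iterator = iter(self._batches)
        try:
            return next(self._iterator)
        except StopIteration:
            # Exhausted: reset and start again from the first batch.
            self._iterator = iter(self._batches)
            return next(self._iterator)


# Ten requests against three batches wrap around instead of raising.
cycler = BatchCycler([0, 1, 2])
seen = [cycler.next_batch() for _ in range(10)]
```

Looping ten times over a shorter dataset, as the original test did, exercises exactly this reset path. Note that an empty dataset would still raise `StopIteration` in this sketch.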
flash/vision/classification/data.py
Outdated
@@ -521,3 +558,68 @@ def from_filepaths(
seed=seed,
**kwargs
)


class _MatplotlibVisualization(BaseVisualization):
Make this class public.
flash/vision/classification/data.py
Outdated
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
    return _MatplotlibVisualization(*args, **kwargs)

def show(self) -> None:
Let's find another way to do that. This is confusing.
We can find better naming here. An alternative, though a bit more hacky, is to pass a blocking flag through all the functions down to the matplotlib plt.show calls.
rows: int = num_samples // cols

# create figure and set title
fig, axs = plt.subplots(rows, cols)
Raise an exception if matplotlib isn't available here: _MATPLOTLIB_AVAILABLE
done, check if that's appropriate
flash/data/base_viz.py
Outdated
# filter out the functions to visualise
func_name: str = self._fcn_white_list[running_stage]
func_names_list: Set[str] = list(set([func_name]) & set(_PREPROCESS_FUNCS))
if len(func_names_list) == 0:
Move this check to the show_{}_batch methods in the DataModule.
Looks good! small comments 😃
Great ! Nicely done !
rows: int = num_samples // cols

if not _MATPLOTLIB_AVAILABLE:
    raise MisconfigurationException("You need matplotlib to visualise. Please, pip install matplotlib")
- raise MisconfigurationException("You need matplotlib to visualise. Please, pip install matplotlib")
+ raise MisconfigurationException("You need matplotlib to visualise. Please, use `pip install matplotlib`")
Awesome! LGTM 😃
What does this PR do?
This PR implements the feature request from #74 by adding ImageClassificationDataVisualizer, which visualizes the images and their associated labels right before they are fed to the model.
Usage proposal:
Result obtained from the unittest:
TBD:
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃