From e76a61a2b1b8c733d91adea46a061513b85f6361 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 9 Aug 2023 11:02:42 +0200 Subject: [PATCH] init `target_formatter` not only typed (#1665) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/flash/audio/classification/input.py | 1 + src/flash/image/classification/input.py | 1 + src/flash/text/classification/input.py | 6 +++++- src/flash/video/classification/input.py | 2 ++ 4 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/flash/audio/classification/input.py b/src/flash/audio/classification/input.py index d0174fdb57..0ed30a155f 100644 --- a/src/flash/audio/classification/input.py +++ b/src/flash/audio/classification/input.py @@ -138,6 +138,7 @@ def load_data( # If we had binary multi-class targets then we also know the labels (column names) if ( self.training + and hasattr(self, "target_formatter") and isinstance(self.target_formatter, MultiBinaryTargetFormatter) and isinstance(target_keys, List) ): diff --git a/src/flash/image/classification/input.py b/src/flash/image/classification/input.py index 7991f595a4..18a9c38096 100644 --- a/src/flash/image/classification/input.py +++ b/src/flash/image/classification/input.py @@ -157,6 +157,7 @@ def load_data( # If we had binary multi-class targets then we also know the labels (column names) if ( self.training + and hasattr(self, "target_formatter") and isinstance(self.target_formatter, MultiBinaryTargetFormatter) and isinstance(target_keys, List) ): diff --git a/src/flash/text/classification/input.py b/src/flash/text/classification/input.py index 54aae52532..cb37a69326 100644 --- a/src/flash/text/classification/input.py +++ b/src/flash/text/classification/input.py @@ -53,7 +53,11 @@ def load_data( self.load_target_metadata(targets, target_formatter=target_formatter) # If we had binary multi-class targets then we also know the labels (column names) - if isinstance(self.target_formatter, MultiBinaryTargetFormatter) and isinstance(target_keys, List): + if ( + hasattr(self, "target_formatter") + and isinstance(self.target_formatter, MultiBinaryTargetFormatter) + and isinstance(target_keys, List) + ): self.labels = target_keys # remove extra columns diff --git a/src/flash/video/classification/input.py b/src/flash/video/classification/input.py index 2d8b42ab54..cc6bdb01cc 100644 --- a/src/flash/video/classification/input.py +++ b/src/flash/video/classification/input.py @@ -215,6 +215,7 @@ def load_data( # If we had binary multi-class targets then we also know the labels (column names) if ( self.training + and hasattr(self, "target_formatter") and isinstance(self.target_formatter, MultiBinaryTargetFormatter) and isinstance(target_keys, List) ): @@ -243,6 +244,7 @@ def load_data( # If we had binary multi-class targets then we also know the labels (column names) if ( self.training + and hasattr(self, "target_formatter") and isinstance(self.target_formatter, MultiBinaryTargetFormatter) and isinstance(targets, List) ):