diff --git a/CHANGELOG.md b/CHANGELOG.md index 1abf240d9c..8505006e41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Changed `DataSource` to `Input` ([#929](https://github.com/PyTorchLightning/lightning-flash/pull/929)) + - Changed `Preprocess` to `InputTransform` ([#951](https://github.com/PyTorchLightning/lightning-flash/pull/951)) - Changed classes named `*Serializer` and properties / variables named `serializer` to be `*Output` and `output` respectively ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927)) @@ -30,6 +32,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `flash.text.seq2seq.core.metrics` in favour of `torchmetrics[text]` ([#648](https://github.com/PyTorchLightning/lightning-flash/pull/648)) +- Deprecated `flash.core.data.data_source.DefaultDataKeys` in favour of `flash.DataKeys` ([#929](https://github.com/PyTorchLightning/lightning-flash/pull/929)) + +- Deprecated `data_source` argument to `flash.Task.predict` in favour of `input` ([#929](https://github.com/PyTorchLightning/lightning-flash/pull/929)) + ### Fixed ### Removed diff --git a/docs/extensions/autodatasources.py b/docs/extensions/autodatasources.py index b1bc85a1ff..98fce308ae 100644 --- a/docs/extensions/autodatasources.py +++ b/docs/extensions/autodatasources.py @@ -61,17 +61,14 @@ def _resolve_transforms(_): return None input_transform = PatchedInputTransform() - data_sources = { - data_source: input_transform.data_source_of_name(data_source) - for data_source in input_transform.available_data_sources() - } + inputs = {input: input_transform.input_of_name(input) for input in input_transform.available_inputs()} ENVIRONMENT.get_template("base.rst") rendered_content = ENVIRONMENT.get_template(data_module_name).render( data_module=f":class:`~{data_module_path}.{data_module_name}`", data_module_raw=data_module_name, 
- data_sources=data_sources, + inputs=inputs, ) node = nodes.section() diff --git a/docs/extensions/templates/base.rst b/docs/extensions/templates/base.rst index 6c5a54468a..93fd50b040 100644 --- a/docs/extensions/templates/base.rst +++ b/docs/extensions/templates/base.rst @@ -7,15 +7,15 @@ This section details the available ways to load your own data into the {{ data_module }}. -{% if 'folders' in data_sources %} +{% if 'folders' in inputs %} {% call render_subsection('from_folders') %} {% block from_folders %} Construct the {{ data_module }} from folders. -{% if data_sources['folders'].extensions is defined %} -The supported file extensions are: {{ data_sources['folders'].extensions|join(', ') }}. -{% set extension = data_sources['folders'].extensions[0] %} +{% if inputs['folders'].extensions is defined %} +The supported file extensions are: {{ inputs['folders'].extensions|join(', ') }}. +{% set extension = inputs['folders'].extensions[0] %} {% else %} {% set extension = '' %} {% endif %} @@ -54,15 +54,15 @@ Example:: {% endblock %} {% endcall %} {% endif %} -{% if 'files' in data_sources %} +{% if 'files' in inputs %} {% call render_subsection('from_files') %} {% block from_files %} Construct the {{ data_module }} from lists of files and corresponding lists of targets. -{% if data_sources['files'].extensions is defined %} -The supported file extensions are: {{ data_sources['files'].extensions|join(', ') }}. -{% set extension = data_sources['files'].extensions[0] %} +{% if inputs['files'].extensions is defined %} +The supported file extensions are: {{ inputs['files'].extensions|join(', ') }}. 
+{% set extension = inputs['files'].extensions[0] %} {% else %} {% set extension = '' %} {% endif %} @@ -80,7 +80,7 @@ Example:: {% endblock %} {% endcall %} {% endif %} -{% if 'datasets' in data_sources %} +{% if 'datasets' in inputs %} {% call render_subsection('from_datasets') %} {% block from_datasets %} diff --git a/docs/source/api/audio.rst b/docs/source/api/audio.rst index 16cdf31d88..a3ae290b63 100644 --- a/docs/source/api/audio.rst +++ b/docs/source/api/audio.rst @@ -34,10 +34,10 @@ __________________ speech_recognition.data.SpeechRecognitionInputTransform speech_recognition.data.SpeechRecognitionBackboneState speech_recognition.data.SpeechRecognitionOutputTransform - speech_recognition.data.SpeechRecognitionCSVDataSource - speech_recognition.data.SpeechRecognitionJSONDataSource + speech_recognition.data.SpeechRecognitionCSVInput + speech_recognition.data.SpeechRecognitionJSONInput speech_recognition.data.BaseSpeechRecognition - speech_recognition.data.SpeechRecognitionFileDataSource - speech_recognition.data.SpeechRecognitionPathsDataSource - speech_recognition.data.SpeechRecognitionDatasetDataSource + speech_recognition.data.SpeechRecognitionFileInput + speech_recognition.data.SpeechRecognitionPathsInput + speech_recognition.data.SpeechRecognitionDatasetInput speech_recognition.data.SpeechRecognitionDeserializer diff --git a/docs/source/api/data.rst b/docs/source/api/data.rst index 0d56e52cdd..36c3da1201 100644 --- a/docs/source/api/data.rst +++ b/docs/source/api/data.rst @@ -81,7 +81,7 @@ _____________________________ ~flash.core.data.data_pipeline.DataPipeline ~flash.core.data.data_pipeline.DataPipelineState -flash.core.data.data_source +flash.core.data.io.input ___________________________ .. 
autosummary:: @@ -89,26 +89,26 @@ ___________________________ :nosignatures: :template: classtemplate.rst - ~flash.core.data.data_source.DatasetDataSource - ~flash.core.data.data_source.DataSource - ~flash.core.data.data_source.DefaultDataKeys - ~flash.core.data.data_source.DefaultDataSources - ~flash.core.data.data_source.FiftyOneDataSource - ~flash.core.data.data_source.ImageLabelsMap - ~flash.core.data.data_source.LabelsState - ~flash.core.data.data_source.MockDataset - ~flash.core.data.data_source.NumpyDataSource - ~flash.core.data.data_source.PathsDataSource - ~flash.core.data.data_source.SequenceDataSource - ~flash.core.data.data_source.TensorDataSource + ~flash.core.data.io.input.DatasetInput + ~flash.core.data.io.input.Input + ~flash.core.data.io.input.DataKeys + ~flash.core.data.io.input.InputFormat + ~flash.core.data.io.input.FiftyOneInput + ~flash.core.data.io.input.ImageLabelsMap + ~flash.core.data.io.input.LabelsState + ~flash.core.data.io.input.MockDataset + ~flash.core.data.io.input.NumpyInput + ~flash.core.data.io.input.PathsInput + ~flash.core.data.io.input.SequenceInput + ~flash.core.data.io.input.TensorInput .. 
autosummary:: :toctree: generated/ :nosignatures: - ~flash.core.data.data_source.has_file_allowed_extension - ~flash.core.data.data_source.has_len - ~flash.core.data.data_source.make_dataset + ~flash.core.data.io.input.has_file_allowed_extension + ~flash.core.data.io.input.has_len + ~flash.core.data.io.input.make_dataset flash.core.data.process _______________________ diff --git a/docs/source/api/flash.rst b/docs/source/api/flash.rst index c83d6fb5f4..cc8775572e 100644 --- a/docs/source/api/flash.rst +++ b/docs/source/api/flash.rst @@ -7,7 +7,7 @@ flash :nosignatures: :template: classtemplate.rst - ~flash.core.data.data_source.DataSource + ~flash.core.data.io.input.Input ~flash.core.data.data_module.DataModule ~flash.core.data.callback.FlashCallback ~flash.core.data.io.output_transform.OutputTransform diff --git a/docs/source/api/graph.rst b/docs/source/api/graph.rst index 65a437cdf4..2abb40bf76 100644 --- a/docs/source/api/graph.rst +++ b/docs/source/api/graph.rst @@ -40,4 +40,4 @@ ________________ :nosignatures: :template: classtemplate.rst - ~data.GraphDatasetDataSource + ~data.GraphDatasetInput diff --git a/docs/source/api/image.rst b/docs/source/api/image.rst index 7d1280a3f3..b8db9c6ad9 100644 --- a/docs/source/api/image.rst +++ b/docs/source/api/image.rst @@ -43,7 +43,7 @@ ________________ ~detection.data.ObjectDetectionData detection.data.FiftyOneParser - detection.data.ObjectDetectionFiftyOneDataSource + detection.data.ObjectDetectionFiftyOneInput detection.output.FiftyOneDetectionLabels detection.data.ObjectDetectionInputTransform @@ -96,10 +96,10 @@ ____________ ~segmentation.data.SemanticSegmentationInputTransform segmentation.data.SegmentationMatplotlibVisualization - segmentation.data.SemanticSegmentationNumpyDataSource - segmentation.data.SemanticSegmentationTensorDataSource - segmentation.data.SemanticSegmentationPathsDataSource - segmentation.data.SemanticSegmentationFiftyOneDataSource + segmentation.data.SemanticSegmentationNumpyInput + 
segmentation.data.SemanticSegmentationTensorInput + segmentation.data.SemanticSegmentationPathsInput + segmentation.data.SemanticSegmentationFiftyOneInput segmentation.data.SemanticSegmentationDeserializer segmentation.model.SemanticSegmentationOutputTransform segmentation.output.FiftyOneSegmentationLabels @@ -140,7 +140,7 @@ ________________ :template: classtemplate.rst ~data.ImageDeserializer - ~data.ImageFiftyOneDataSource - ~data.ImageNumpyDataSource - ~data.ImagePathsDataSource - ~data.ImageTensorDataSource + ~data.ImageFiftyOneInput + ~data.ImageNumpyInput + ~data.ImagePathsInput + ~data.ImageTensorInput diff --git a/docs/source/api/pointcloud.rst b/docs/source/api/pointcloud.rst index dc4b777423..298476b291 100644 --- a/docs/source/api/pointcloud.rst +++ b/docs/source/api/pointcloud.rst @@ -21,8 +21,8 @@ ____________ ~segmentation.data.PointCloudSegmentationData segmentation.data.PointCloudSegmentationInputTransform - segmentation.data.PointCloudSegmentationFoldersDataSource - segmentation.data.PointCloudSegmentationDatasetDataSource + segmentation.data.PointCloudSegmentationFoldersInput + segmentation.data.PointCloudSegmentationDatasetInput Object Detection ________________ @@ -36,5 +36,5 @@ ________________ ~detection.data.PointCloudObjectDetectorData detection.data.PointCloudObjectDetectorInputTransform - detection.data.PointCloudObjectDetectorFoldersDataSource - detection.data.PointCloudObjectDetectorDatasetDataSource + detection.data.PointCloudObjectDetectorFoldersInput + detection.data.PointCloudObjectDetectorDatasetInput diff --git a/docs/source/api/tabular.rst b/docs/source/api/tabular.rst index a258890495..7defbb9871 100644 --- a/docs/source/api/tabular.rst +++ b/docs/source/api/tabular.rst @@ -43,7 +43,7 @@ ___________ ~forecasting.data.TabularForecastingData forecasting.data.TabularForecastingInputTransform - forecasting.data.TabularForecastingDataFrameDataSource + forecasting.data.TabularForecastingDataFrameInput 
forecasting.data.TimeSeriesDataSetParametersState flash.tabular.data @@ -55,8 +55,8 @@ __________________ :template: classtemplate.rst ~data.TabularData - ~data.TabularDataFrameDataSource - ~data.TabularCSVDataSource + ~data.TabularDataFrameInput + ~data.TabularCSVInput ~data.TabularDeserializer ~data.TabularOutputTransform ~data.TabularInputTransform diff --git a/docs/source/api/text.rst b/docs/source/api/text.rst index d692994aa8..deac52117b 100644 --- a/docs/source/api/text.rst +++ b/docs/source/api/text.rst @@ -23,13 +23,13 @@ ______________ classification.data.TextClassificationOutputTransform classification.data.TextClassificationInputTransform classification.data.TextDeserializer - classification.data.TextDataSource - classification.data.TextCSVDataSource - classification.data.TextJSONDataSource - classification.data.TextDataFrameDataSource - classification.data.TextParquetDataSource - classification.data.TextHuggingFaceDatasetDataSource - classification.data.TextListDataSource + classification.data.TextInput + classification.data.TextCSVInput + classification.data.TextJSONInput + classification.data.TextDataFrameInput + classification.data.TextParquetInput + classification.data.TextHuggingFaceDatasetInput + classification.data.TextListInput Question Answering __________________ @@ -43,15 +43,14 @@ __________________ ~question_answering.data.QuestionAnsweringData question_answering.data.QuestionAnsweringBackboneState - question_answering.data.QuestionAnsweringCSVDataSource - question_answering.data.QuestionAnsweringDataSource - question_answering.data.QuestionAnsweringDictionaryDataSource - question_answering.data.QuestionAnsweringFileDataSource - question_answering.data.QuestionAnsweringJSONDataSource + question_answering.data.QuestionAnsweringCSVInput + question_answering.data.QuestionAnsweringInput + question_answering.data.QuestionAnsweringDictionaryInput + question_answering.data.QuestionAnsweringFileInput + 
question_answering.data.QuestionAnsweringJSONInput question_answering.data.QuestionAnsweringOutputTransform question_answering.data.QuestionAnsweringInputTransform - question_answering.data.SQuADDataSource - + question_answering.data.SQuADInput Summarization _____________ @@ -92,10 +91,11 @@ _______________ ~seq2seq.core.finetuning.Seq2SeqFreezeEmbeddings seq2seq.core.data.Seq2SeqBackboneState - seq2seq.core.data.Seq2SeqCSVDataSource - seq2seq.core.data.Seq2SeqDataSource - seq2seq.core.data.Seq2SeqFileDataSource - seq2seq.core.data.Seq2SeqJSONDataSource + seq2seq.core.data.Seq2SeqCSVInput + seq2seq.core.data.Seq2SeqInput + seq2seq.core.data.Seq2SeqFileInput + seq2seq.core.data.Seq2SeqJSONInput seq2seq.core.data.Seq2SeqOutputTransform seq2seq.core.data.Seq2SeqInputTransform - seq2seq.core.data.Seq2SeqSentencesDataSource + seq2seq.core.data.Seq2SeqSentencesInput + seq2seq.core.metrics.BLEUScore diff --git a/docs/source/api/video.rst b/docs/source/api/video.rst index f9825041f0..c5eb46210c 100644 --- a/docs/source/api/video.rst +++ b/docs/source/api/video.rst @@ -21,7 +21,7 @@ ______________ ~classification.data.VideoClassificationData classification.data.BaseVideoClassification - classification.data.VideoClassificationFiftyOneDataSource - classification.data.VideoClassificationPathsDataSource + classification.data.VideoClassificationFiftyOneInput + classification.data.VideoClassificationPathsInput classification.data.VideoClassificationInputTransform classification.model.VideoClassifierFinetuning diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst index 5dc0e71387..2022ea6314 100644 --- a/docs/source/general/data.rst +++ b/docs/source/general/data.rst @@ -26,9 +26,9 @@ Here are common terms you need to be familiar with: * - :class:`~flash.core.data.data_module.DataModule` - The :class:`~flash.core.data.data_module.DataModule` contains the datasets, transforms and dataloaders. 
* - :class:`~flash.core.data.data_pipeline.DataPipeline` - - The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.core.data.Deserializer`, :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.io.input_transform.InputTransform`, :class:`~flash.core.data.io.output_transform.OutputTransform`, and :class:`~flash.core.data.io.output.Output` objects. - * - :class:`~flash.core.data.data_source.DataSource` - - The :class:`~flash.core.data.data_source.DataSource` provides :meth:`~flash.core.data.data_source.DataSource.load_data` and :meth:`~flash.core.data.data_source.DataSource.load_sample` hooks for creating data sets from metadata (such as folder names). + - The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.core.data.Deserializer`, :class:`~flash.core.data.io.input.Input`, :class:`~flash.core.data.io.input_transform.InputTransform`, :class:`~flash.core.data.io.output_transform.OutputTransform`, and :class:`~flash.core.data.io.output.Output` objects. + * - :class:`~flash.core.data.io.input.Input` + - The :class:`~flash.core.data.io.input.Input` provides :meth:`~flash.core.data.io.input.Input.load_data` and :meth:`~flash.core.data.io.input.Input.load_sample` hooks for creating data sets from metadata (such as folder names). * - :class:`~flash.core.data.io.input_transform.InputTransform` - The :class:`~flash.core.data.io.input_transform.InputTransform` provides a simple hook-based API to encapsulate your pre-processing logic. These hooks (such as :meth:`~flash.core.data.io.input_transform.InputTransform.pre_tensor_transform`) enable transformations to be applied to your data at every point along the pipeline (including on the device). @@ -57,7 +57,7 @@ and provide it to a :class:`torch.utils.data.DataLoader`. 
However, after model training, it requires a lot of engineering overhead to make inference on raw data and deploy the model in production environment. Usually, extra processing logic should be added to bridge the gap between training data and raw data. -The :class:`~flash.core.data.data_source.DataSource` class can be used to generate data sets from multiple sources (e.g. folders, numpy, etc.), that can then all be transformed in the same way. +The :class:`~flash.core.data.io.input.Input` class can be used to generate data sets from multiple sources (e.g. folders, numpy, etc.), that can then all be transformed in the same way. The :class:`~flash.core.data.io.input_transform.InputTransform` and :class:`~flash.core.data.io.output_transform.OutputTransform` classes can be used to manage the input and output transforms. The :class:`~flash.core.data.io.output.Output` class provides the logic for converting :class:`~flash.core.data.io.output_transform.OutputTransform` outputs to the desired predict format (e.g. classes, labels, probabilities, etc.). @@ -94,8 +94,8 @@ Any Flash :class:`~flash.core.data.data_module.DataModule` can be created direct The :class:`~flash.core.data.data_module.DataModule` provides additional ``classmethod`` helpers (``from_*``) for loading data from various sources. -In each ``from_*`` method, the :class:`~flash.core.data.data_module.DataModule` internally retrieves the correct :class:`~flash.core.data.data_source.DataSource` to use from the :class:`~flash.core.data.io.input_transform.InputTransform`. -Flash :class:`~flash.core.data.auto_dataset.AutoDataset` instances are created from the :class:`~flash.core.data.data_source.DataSource` for train, val, test, and predict. +In each ``from_*`` method, the :class:`~flash.core.data.data_module.DataModule` internally retrieves the correct :class:`~flash.core.data.io.input.Input` to use from the :class:`~flash.core.data.io.input_transform.InputTransform`. 
+Flash :class:`~flash.core.data.auto_dataset.AutoDataset` instances are created from the :class:`~flash.core.data.io.input.Input` for train, val, test, and predict. The :class:`~flash.core.data.data_module.DataModule` populates the ``DataLoader`` for each stage with the corresponding :class:`~flash.core.data.auto_dataset.AutoDataset`. ************************************** @@ -149,7 +149,7 @@ Alternatively, the user may directly override the hooks for their needs like thi Create your own InputTransform and DataModule ********************************************* -The example below shows a very simple ``ImageClassificationInputTransform`` with a single ``ImageClassificationFoldersDataSource`` and an ``ImageClassificationDataModule``. +The example below shows a very simple ``ImageClassificationInputTransform`` with a single ``ImageClassificationFoldersInput`` and an ``ImageClassificationDataModule``. 1. User-Facing API design _________________________ @@ -180,23 +180,23 @@ Example:: trainer.fit(model, dm) -2. The DataSource +2. The Input _________________ -We start by implementing the ``ImageClassificationFoldersDataSource``. +We start by implementing the ``ImageClassificationFoldersInput``. The ``load_data`` method will produce a list of files and targets from the given directory. The ``load_sample`` method will load the given file as a ``PIL.Image``. -Here's the full ``ImageClassificationFoldersDataSource``: +Here's the full ``ImageClassificationFoldersInput``: .. code-block:: python from PIL import Image from torchvision.datasets.folder import make_dataset from typing import Any, Dict - from flash.core.data.data_source import DataSource, DefaultDataKeys + from flash.core.data.io.input import Input, DataKeys - class ImageClassificationFoldersDataSource(DataSource): + class ImageClassificationFoldersInput(Input): def load_data(self, folder: str, dataset: Any) -> Iterable: # The dataset is optional but can be useful to save some metadata. 
@@ -211,21 +211,21 @@ Here's the full ``ImageClassificationFoldersDataSource``: return [ { - DefaultDataKeys.INPUT: file, - DefaultDataKeys.TARGET: target, + DataKeys.INPUT: file, + DataKeys.TARGET: target, } for file, target in metadata ] def predict_load_data(self, predict_folder: str) -> Iterable: # This returns [image_path_1, ... image_path_m]. - return [{DefaultDataKeys.INPUT: file} for file in os.listdir(folder)] + return [{DataKeys.INPUT: file} for file in os.listdir(predict_folder)] def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - sample[DefaultDataKeys.INPUT] = Image.open(sample[DefaultDataKeys.INPUT]) + sample[DataKeys.INPUT] = Image.open(sample[DataKeys.INPUT]) return sample -.. note:: We return samples as dictionaries using the :class:`~flash.core.data.data_source.DefaultDataKeys` by convention. This is the recommended (although not required) way to represent data in Flash. +.. note:: We return samples as dictionaries using the :class:`~flash.core.data.io.input.DataKeys` by convention. This is the recommended (although not required) way to represent data in Flash. 3. The InputTransform _____________________ @@ -235,7 +235,7 @@ Next, implement your custom ``ImageClassificationInputTransform`` with some defa ..
code-block:: python from typing import Any, Callable, Dict, Optional - from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources + from flash.core.data.io.input import DataKeys, InputFormat from flash.core.data.io.input_transform import InputTransform import torchvision.transforms.functional as T @@ -253,10 +253,10 @@ Next, implement your custom ``ImageClassificationInputTransform`` with some defa val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.FOLDERS: ImageClassificationFoldersDataSource(), + inputs={ + InputFormat.FOLDERS: ImageClassificationFoldersInput(), }, - default_data_source=DefaultDataSources.FOLDERS, + default_input=InputFormat.FOLDERS, ) def get_state_dict(self) -> Dict[str, Any]: @@ -267,13 +267,13 @@ Next, implement your custom ``ImageClassificationInputTransform`` with some defa return cls(**state_dict) def default_transforms(self) -> Dict[str, Callable]: - return {"to_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.to_tensor)} + return {"to_tensor_transform": ApplyToKeys(DataKeys.INPUT, T.to_tensor)} 4. The DataModule _________________ Finally, let's implement the ``ImageClassificationDataModule``. -We get the ``from_folders`` classmethod for free as we've registered a ``DefaultDataSources.FOLDERS`` data source in our ``ImageClassificationInputTransform``. +We get the ``from_folders`` classmethod for free as we've registered an ``InputFormat.FOLDERS`` data source in our ``ImageClassificationInputTransform``. All we need to do is attach our :class:`~flash.core.data.io.input_transform.InputTransform` class like this: .. code-block:: python @@ -291,12 +291,12 @@ All we need to do is attach our :class:`~flash.core.data.io.input_transform.Inpu How it works behind the scenes ****************************** -DataSource -__________ +Input +_____ ..
note:: - The :meth:`~flash.core.data.data_source.DataSource.load_data` and - :meth:`~flash.core.data.data_source.DataSource.load_sample` will be used to generate an + The :meth:`~flash.core.data.io.input.Input.load_data` and + :meth:`~flash.core.data.io.input.Input.load_sample` will be used to generate an :class:`~flash.core.data.auto_dataset.AutoDataset` object. Here is the :class:`~flash.core.data.auto_dataset.AutoDataset` pseudo-code. @@ -306,16 +306,16 @@ Here is the :class:`~flash.core.data.auto_dataset.AutoDataset` pseudo-code. class AutoDataset: def __init__( self, - data: List[Any], # output of `DataSource.load_data` - data_source: DataSource, + data: List[Any], # output of `Input.load_data` + input: Input, running_stage: RunningStage, ): self.data = data - self.data_source = data_source + self.input = input def __getitem__(self, index: int): - return self.data_source.load_sample(self.data[index]) + return self.input.load_sample(self.data[index]) def __len__(self): return len(self.data) diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index fc10f2b713..0ef504993a 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -120,13 +120,13 @@ We use the `post_tensor_transform` hook to apply the transformations after the i from torchvision import transforms as T import flash - from flash.core.data.data_source import DefaultDataKeys + from flash.core.data.io.input import DataKeys from flash.core.data.transforms import ApplyToKeys, merge_transforms from flash.image import ImageClassificationData, ImageClassifier from flash.image.classification.transforms import default_transforms post_tensor_transform = ApplyToKeys( - DefaultDataKeys.INPUT, + DataKeys.INPUT, T.Compose([T.RandomHorizontalFlip(), T.ColorJitter(), T.RandomAutocontrast(), T.RandomPerspective()]), ) diff --git a/docs/source/reference/semantic_segmentation.rst 
b/docs/source/reference/semantic_segmentation.rst index c88c67c8f8..235d9af383 100644 --- a/docs/source/reference/semantic_segmentation.rst +++ b/docs/source/reference/semantic_segmentation.rst @@ -81,9 +81,9 @@ Loading Data {% block from_folders %} Construct the {{ data_module }} from folders. - {% if data_sources['folders'].extensions is defined %} - The supported file extensions are: {{ data_sources['folders'].extensions|join(', ') }}. - {% set extension = data_sources['folders'].extensions[0] %} + {% if inputs['folders'].extensions is defined %} + The supported file extensions are: {{ inputs['folders'].extensions|join(', ') }}. + {% set extension = inputs['folders'].extensions[0] %} {% else %} {% set extension = '' %} {% endif %} @@ -124,9 +124,9 @@ Loading Data {% block from_files %} Construct the {{ data_module }} from lists of input images and corresponding list of target images. - {% if data_sources['files'].extensions is defined %} - The supported file extensions are: {{ data_sources['files'].extensions|join(', ') }}. - {% set extension = data_sources['files'].extensions[0] %} + {% if inputs['files'].extensions is defined %} + The supported file extensions are: {{ inputs['files'].extensions|join(', ') }}. + {% set extension = inputs['files'].extensions[0] %} {% else %} {% set extension = '' %} {% endif %} diff --git a/docs/source/template/data.rst b/docs/source/template/data.rst index d5eb6b03e6..5ce997842f 100644 --- a/docs/source/template/data.rst +++ b/docs/source/template/data.rst @@ -7,79 +7,79 @@ The Data The first step to contributing a task is to implement the classes we need to load some data. Inside `data.py `_ you should implement: -#. some :class:`~flash.core.data.data_source.DataSource` classes *(optional)* +#. some :class:`~flash.core.data.io.input.Input` classes *(optional)* #. a :class:`~flash.core.data.io.input_transform.InputTransform` #. a :class:`~flash.core.data.data_module.DataModule` #. 
a :class:`~flash.core.data.base_viz.BaseVisualization` *(optional)* #. a :class:`~flash.core.data.io.output_transform.OutputTransform` *(optional)* -DataSource -^^^^^^^^^^ +Input +^^^^^ -The :class:`~flash.core.data.data_source.DataSource` class contains the logic for data loading from different sources such as folders, files, tensors, etc. +The :class:`~flash.core.data.io.input.Input` class contains the logic for data loading from different sources such as folders, files, tensors, etc. Every Flash :class:`~flash.core.data.data_module.DataModule` can be instantiated with :meth:`~flash.core.data.data_module.DataModule.from_datasets`. -For each additional way you want the user to be able to instantiate your :class:`~flash.core.data.data_module.DataModule`, you'll need to create a :class:`~flash.core.data.data_source.DataSource`. -Each :class:`~flash.core.data.data_source.DataSource` has 2 methods: +For each additional way you want the user to be able to instantiate your :class:`~flash.core.data.data_module.DataModule`, you'll need to create a :class:`~flash.core.data.io.input.Input`. +Each :class:`~flash.core.data.io.input.Input` has 2 methods: -- :meth:`~flash.core.data.data_source.DataSource.load_data` takes some dataset metadata (e.g. a folder name) as input and produces a sequence or iterable of samples or sample metadata. -- :meth:`~flash.core.data.data_source.DataSource.load_sample` then takes as input a single element from the output of ``load_data`` and returns a sample. +- :meth:`~flash.core.data.io.input.Input.load_data` takes some dataset metadata (e.g. a folder name) as input and produces a sequence or iterable of samples or sample metadata. +- :meth:`~flash.core.data.io.input.Input.load_sample` then takes as input a single element from the output of ``load_data`` and returns a sample. 
-By default these methods just return their input, so you don't need both a :meth:`~flash.core.data.data_source.DataSource.load_data` and a :meth:`~flash.core.data.data_source.DataSource.load_sample` to create a :class:`~flash.core.data.data_source.DataSource`. -Where possible, you should override one of our existing :class:`~flash.core.data.data_source.DataSource` classes. +By default these methods just return their input, so you don't need both a :meth:`~flash.core.data.io.input.Input.load_data` and a :meth:`~flash.core.data.io.input.Input.load_sample` to create a :class:`~flash.core.data.io.input.Input`. +Where possible, you should override one of our existing :class:`~flash.core.data.io.input.Input` classes. -Let's start by implementing a ``TemplateNumpyDataSource``, which overrides :class:`~flash.core.data.data_source.NumpyDataSource`. -The main :class:`~flash.core.data.data_source.DataSource` method that we have to implement is :meth:`~flash.core.data.data_source.DataSource.load_data`. -As we're extending the ``NumpyDataSource``, we expect the same ``data`` argument (in this case, a tuple containing data and corresponding target arrays). +Let's start by implementing a ``TemplateNumpyInput``, which overrides :class:`~flash.core.data.io.input.NumpyInput`. +The main :class:`~flash.core.data.io.input.Input` method that we have to implement is :meth:`~flash.core.data.io.input.Input.load_data`. +As we're extending the ``NumpyInput``, we expect the same ``data`` argument (in this case, a tuple containing data and corresponding target arrays). We can also take the dataset argument. -Any attributes we set on ``dataset`` will be available on the :class:`~torch.utils.data.Dataset` generated by our :class:`~flash.core.data.data_source.DataSource`. +Any attributes we set on ``dataset`` will be available on the :class:`~torch.utils.data.Dataset` generated by our :class:`~flash.core.data.io.input.Input`. In this data source, we'll set the ``num_features`` attribute. 
-Here's the code for our ``TemplateNumpyDataSource.load_data`` method: +Here's the code for our ``TemplateNumpyInput.load_data`` method: .. literalinclude:: ../../../flash/template/classification/data.py :language: python :dedent: 4 - :pyobject: TemplateNumpyDataSource.load_data + :pyobject: TemplateNumpyInput.load_data .. note:: Later, when we add :ref:`our DataModule implementation `, we'll make ``num_features`` available to the user. Sometimes you need to something a bit more custom. -When creating a custom :class:`~flash.core.data.data_source.DataSource`, the type of the ``data`` argument is up to you. +When creating a custom :class:`~flash.core.data.io.input.Input`, the type of the ``data`` argument is up to you. For our template :class:`~flash.core.data.model.Task`, it would be cool if the user could provide a scikit-learn ``Bunch`` as the data source. -To achieve this, we'll add a ``TemplateSKLearnDataSource`` whose ``load_data`` expects a ``Bunch`` as input. -We override our ``TemplateNumpyDataSource`` so that we can call ``super`` with the data and targets extracted from the ``Bunch``. +To achieve this, we'll add a ``TemplateSKLearnInput`` whose ``load_data`` expects a ``Bunch`` as input. +We override our ``TemplateNumpyInput`` so that we can call ``super`` with the data and targets extracted from the ``Bunch``. We perform two additional steps here to improve the user experience: 1. We set the ``num_classes`` attribute on the ``dataset``. If ``num_classes`` is set, it is automatically made available as a property of the :class:`~flash.core.data.data_module.DataModule`. -2. We create and set a :class:`~flash.core.data.data_source.LabelsState`. The labels provided here will be shared with the :class:`~flash.core.classification.Labels` serializer, so the user doesn't need to provide them. +2. We create and set a :class:`~flash.core.data.io.input.LabelsState`. 
The labels provided here will be shared with the :class:`~flash.core.classification.Labels` serializer, so the user doesn't need to provide them. -Here's the code for the ``TemplateSKLearnDataSource.load_data`` method: +Here's the code for the ``TemplateSKLearnInput.load_data`` method: .. literalinclude:: ../../../flash/template/classification/data.py :language: python :dedent: 4 - :pyobject: TemplateSKLearnDataSource.load_data + :pyobject: TemplateSKLearnInput.load_data -We can customize the behaviour of our :meth:`~flash.core.data.data_source.DataSource.load_data` for different stages, by prepending `train`, `val`, `test`, or `predict`. -For our ``TemplateSKLearnDataSource``, we don't want to provide any targets to the model when predicting. +We can customize the behaviour of our :meth:`~flash.core.data.io.input.Input.load_data` for different stages, by prepending `train`, `val`, `test`, or `predict`. +For our ``TemplateSKLearnInput``, we don't want to provide any targets to the model when predicting. We can implement ``predict_load_data`` like this: .. literalinclude:: ../../../flash/template/classification/data.py :language: python :dedent: 4 - :pyobject: TemplateSKLearnDataSource.predict_load_data + :pyobject: TemplateSKLearnInput.predict_load_data -DataSource vs Dataset -~~~~~~~~~~~~~~~~~~~~~ +Input vs Dataset +~~~~~~~~~~~~~~~~ -A :class:`~flash.core.data.data_source.DataSource` is not the same as a :class:`torch.utils.data.Dataset`. -When a ``from_*`` method is called on your :class:`~flash.core.data.data_module.DataModule`, it gets the :class:`~flash.core.data.data_source.DataSource` to use from the :class:`~flash.core.data.io.input_transform.InputTransform`. -A :class:`~torch.utils.data.Dataset` is then created from the :class:`~flash.core.data.data_source.DataSource` for each stage (`train`, `val`, `test`, `predict`) using the provided metadata (e.g. folder name, numpy array etc.). 
+A :class:`~flash.core.data.io.input.Input` is not the same as a :class:`torch.utils.data.Dataset`. +When a ``from_*`` method is called on your :class:`~flash.core.data.data_module.DataModule`, it gets the :class:`~flash.core.data.io.input.Input` to use from the :class:`~flash.core.data.io.input_transform.InputTransform`. +A :class:`~torch.utils.data.Dataset` is then created from the :class:`~flash.core.data.io.input.Input` for each stage (`train`, `val`, `test`, `predict`) using the provided metadata (e.g. folder name, numpy array etc.). -The output of the :meth:`~flash.core.data.data_source.DataSource.load_data` can just be a :class:`torch.utils.data.Dataset` instance. -If the library that your :class:`~flash.core.data.model.Task` is based on provides a custom dataset, you don't need to re-write it as a :class:`~flash.core.data.data_source.DataSource`. -For example, the :meth:`~flash.core.data.data_source.DataSource.load_data` of the ``VideoClassificationPathsDataSource`` just creates an :class:`~pytorchvideo.data.EncodedVideoDataset` from the given folder. +The output of the :meth:`~flash.core.data.io.input.Input.load_data` can just be a :class:`torch.utils.data.Dataset` instance. +If the library that your :class:`~flash.core.data.model.Task` is based on provides a custom dataset, you don't need to re-write it as a :class:`~flash.core.data.io.input.Input`. +For example, the :meth:`~flash.core.data.io.input.Input.load_data` of the ``VideoClassificationPathsInput`` just creates an :class:`~pytorchvideo.data.EncodedVideoDataset` from the given folder. Here's how it looks (from `video/classification.data.py `_): .. literalinclude:: ../../../flash/video/classification/data.py @@ -101,12 +101,12 @@ Any additional arguments are up to you. Inside the ``__init__``, we make a call to super. This is where we register our data sources. Data sources should be given as a dictionary which maps data source name to data source object. 
-The name can be anything, but if you want to take advantage of our built-in ``from_*`` classmethods, you should use :class:`~flash.core.data.data_source.DefaultDataSources` as the names. -In our case, we have both a :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` and a custom scikit-learn data source (which we'll call `"sklearn"`). +The name can be anything, but if you want to take advantage of our built-in ``from_*`` classmethods, you should use :class:`~flash.core.data.io.input.InputFormat` as the names. +In our case, we have both a :attr:`~flash.core.data.io.input.InputFormat.NUMPY` and a custom scikit-learn data source (which we'll call `"sklearn"`). -You should also provide a ``default_data_source``. +You should also provide a ``default_input``. This is the name of the data source to use by default when predicting. -It'd be cool if we could get predictions just from a numpy array, so we'll use :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` as the default. +It'd be cool if we could get predictions just from a numpy array, so we'll use :attr:`~flash.core.data.io.input.InputFormat.NUMPY` as the default. Here's our ``TemplateInputTransform.__init__``: @@ -123,7 +123,7 @@ Let's first define the transform as a ``staticmethod``: :dedent: 4 :pyobject: TemplateInputTransform.input_to_tensor -Our inputs samples will be dictionaries whose keys are in the :class:`~flash.core.data.data_source.DefaultDataKeys`. +Our inputs samples will be dictionaries whose keys are in the :class:`~flash.core.data.io.input.DataKeys`. You can map each key to different transforms using :class:`~flash.core.data.transforms.ApplyToKeys`. Here's our ``default_transforms`` method: @@ -140,10 +140,10 @@ DataModule The :class:`~flash.core.data.data_module.DataModule` is responsible for creating the :class:`~torch.utils.data.DataLoader` and injecting the transforms for each stage. 
When the user calls a ``from_*`` method (such as :meth:`~flash.core.data.data_module.DataModule.from_numpy`), the following steps take place: -#. The :meth:`~flash.core.data.data_module.DataModule.from_data_source` method is called with the name of the :class:`~flash.core.data.data_source.DataSource` to use and the inputs to provide to :meth:`~flash.core.data.data_source.DataSource.load_data` for each stage. +#. The :meth:`~flash.core.data.data_module.DataModule.from_input` method is called with the name of the :class:`~flash.core.data.io.input.Input` to use and the inputs to provide to :meth:`~flash.core.data.io.input.Input.load_data` for each stage. #. The :class:`~flash.core.data.io.input_transform.InputTransform` is created from ``cls.input_transform_cls`` (if it wasn't provided by the user) with any provided transforms. -#. The :class:`~flash.core.data.data_source.DataSource` of the provided name is retrieved from the :class:`~flash.core.data.io.input_transform.InputTransform`. -#. A :class:`~flash.core.data.auto_dataset.BaseAutoDataset` is created from the :class:`~flash.core.data.data_source.DataSource` for each stage. +#. The :class:`~flash.core.data.io.input.Input` of the provided name is retrieved from the :class:`~flash.core.data.io.input_transform.InputTransform`. +#. A :class:`~flash.core.data.auto_dataset.BaseAutoDataset` is created from the :class:`~flash.core.data.io.input.Input` for each stage. #. The :class:`~flash.core.data.data_module.DataModule` is instantiated with the data sets. | @@ -154,9 +154,9 @@ To create our ``TemplateData`` :class:`~flash.core.data.data_module.DataModule`, input_transform_cls = TemplateInputTransform -Since we provided a :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` :class:`~flash.core.data.data_source.DataSource` in the ``TemplateInputTransform``, :meth:`~flash.core.data.data_module.DataModule.from_numpy` will now work with our ``TemplateData``. 
+Since we provided a :attr:`~flash.core.data.io.input.InputFormat.NUMPY` :class:`~flash.core.data.io.input.Input` in the ``TemplateInputTransform``, :meth:`~flash.core.data.data_module.DataModule.from_numpy` will now work with our ``TemplateData``. -If you've defined a fully custom :class:`~flash.core.data.data_source.DataSource` (like our ``TemplateSKLearnDataSource``), then you will need to write a ``from_*`` method for each. +If you've defined a fully custom :class:`~flash.core.data.io.input.Input` (like our ``TemplateSKLearnInput``), then you will need to write a ``from_*`` method for each. Here's the ``from_sklearn`` method for our ``TemplateData``: .. literalinclude:: ../../../flash/template/classification/data.py @@ -207,18 +207,18 @@ As an example, here's the :class:`~text.classification.data.TextClassificationOu :language: python :pyobject: TextClassificationOutputTransform -In your :class:`~flash.core.data.data_source.DataSource` or :class:`~flash.core.data.io.input_transform.InputTransform`, you can add metadata to the batch using the :attr:`~flash.core.data.data_source.DefaultDataKeys.METADATA` key. +In your :class:`~flash.core.data.io.input.Input` or :class:`~flash.core.data.io.input_transform.InputTransform`, you can add metadata to the batch using the :attr:`~flash.core.data.io.input.DataKeys.METADATA` key. Your :class:`~flash.core.data.io.output_transform.OutputTransform` can then use this metadata in its transforms. You should use this approach if your postprocessing depends on the state of the input before the :class:`~flash.core.data.io.input_transform.InputTransform` transforms. -For example, if you want to resize the predictions to the original size of the inputs you should add the original image size in the :attr:`~flash.core.data.data_source.DefaultDataKeys.METADATA`. 
-Here's an example from the :class:`~flash.image.segmentation.SemanticSegmentationNumpyDataSource`: +For example, if you want to resize the predictions to the original size of the inputs you should add the original image size in the :attr:`~flash.core.data.io.input.DataKeys.METADATA`. +Here's an example from the :class:`~flash.image.segmentation.SemanticSegmentationNumpyInput`: .. literalinclude:: ../../../flash/image/segmentation/data.py :language: python :dedent: 4 - :pyobject: SemanticSegmentationNumpyDataSource.load_sample + :pyobject: SemanticSegmentationNumpyInput.load_sample -The :attr:`~flash.core.data.data_source.DefaultDataKeys.METADATA` can now be referenced in your :class:`~flash.core.data.io.output_transform.OutputTransform`. +The :attr:`~flash.core.data.io.input.DataKeys.METADATA` can now be referenced in your :class:`~flash.core.data.io.output_transform.OutputTransform`. For example, here's the code for the ``per_sample_transform`` method of the :class:`~flash.image.segmentation.model.SemanticSegmentationOutputTransform`: .. literalinclude:: ../../../flash/image/segmentation/model.py diff --git a/docs/source/template/tests.rst b/docs/source/template/tests.rst index 33d85952fb..b06397f99f 100644 --- a/docs/source/template/tests.rst +++ b/docs/source/template/tests.rst @@ -36,7 +36,7 @@ test_data.py The most important tests in `test_data.py `_ check that the ``from_*`` methods work correctly. In the class ``TestTemplateData``, we have two of these: ``test_from_numpy`` and ``test_from_sklearn``. -In general, there should be one ``test_from_*`` method for each :class:`~flash.core.data.data_source` you have configured. +In general, there should be one ``test_from_*`` method for each :class:`~flash.core.data.io.input.Input` you have configured. 
Here's the code for ``test_from_numpy``: @@ -76,7 +76,7 @@ These tests are very similar to ``test_train``, but here they are for completene We also include tests for prediction named ``test_predict_*`` for each of our data sources. In our case, we have ``test_predict_numpy`` and ``test_predict_sklearn``. -These tests should use the ``data_source`` argument to :meth:`~flash.core.model.Task.predict` to select the required :class:`~flash.core.data.DataSource`. +These tests should use the ``input`` argument to :meth:`~flash.core.model.Task.predict` to select the required :class:`~flash.core.data.io.input.Input`. Here's ``test_predict_sklearn`` as an example: .. literalinclude:: ../../../tests/template/classification/test_model.py diff --git a/flash/__init__.py b/flash/__init__.py index 5b5414f9f3..620fca1d1c 100644 --- a/flash/__init__.py +++ b/flash/__init__.py @@ -20,15 +20,15 @@ if _TORCH_AVAILABLE: from flash.core.data.callback import FlashCallback - from flash.core.data.data_module import DataModule # noqa: E402 - from flash.core.data.data_source import DataSource + from flash.core.data.data_module import DataModule from flash.core.data.datasets import FlashDataset, FlashIterableDataset + from flash.core.data.io.input import DataKeys, Input from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output import Output from flash.core.data.io.output_transform import OutputTransform from flash.core.data.process import Serializer - from flash.core.model import Task # noqa: E402 - from flash.core.trainer import Trainer # noqa: E402 + from flash.core.model import Task + from flash.core.trainer import Trainer _PACKAGE_ROOT = os.path.dirname(__file__) ASSETS_ROOT = os.path.join(_PACKAGE_ROOT, "assets") @@ -41,11 +41,12 @@ seed_everything(42) __all__ = [ - "DataSource", + "DataKeys", "DataModule", "FlashCallback", "FlashDataset", "FlashIterableDataset", + "Input", "InputTransform", "Output", "OutputTransform", diff --git 
a/flash/audio/classification/data.py b/flash/audio/classification/data.py index 4d6edb02bd..0b3c688f94 100644 --- a/flash/audio/classification/data.py +++ b/flash/audio/classification/data.py @@ -16,13 +16,13 @@ import numpy as np from flash.audio.classification.transforms import default_transforms, train_default_transforms -from flash.core.data.data_source import ( - DefaultDataKeys, - DefaultDataSources, +from flash.core.data.io.input import ( + DataKeys, has_file_allowed_extension, - LoaderDataFrameDataSource, - NumpyDataSource, - PathsDataSource, + InputFormat, + LoaderDataFrameInput, + NumpyInput, + PathsInput, ) from flash.core.data.io.input_transform import InputTransform from flash.core.data.process import Deserializer @@ -40,24 +40,24 @@ def spectrogram_loader(filepath: str): return data -class AudioClassificationNumpyDataSource(NumpyDataSource): +class AudioClassificationNumpyInput(NumpyInput): def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - sample[DefaultDataKeys.INPUT] = np.transpose(sample[DefaultDataKeys.INPUT], (1, 2, 0)) + sample[DataKeys.INPUT] = np.transpose(sample[DataKeys.INPUT], (1, 2, 0)) return sample -class AudioClassificationTensorDataSource(AudioClassificationNumpyDataSource): +class AudioClassificationTensorInput(AudioClassificationNumpyInput): def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - sample[DefaultDataKeys.INPUT] = sample[DefaultDataKeys.INPUT].numpy() + sample[DataKeys.INPUT] = sample[DataKeys.INPUT].numpy() return super().load_sample(sample, dataset=dataset) -class AudioClassificationPathsDataSource(PathsDataSource): +class AudioClassificationPathsInput(PathsInput): def __init__(self): super().__init__(loader=spectrogram_loader, extensions=IMG_EXTENSIONS + NP_EXTENSIONS) -class AudioClassificationDataFrameDataSource(LoaderDataFrameDataSource): +class AudioClassificationDataFrameInput(LoaderDataFrameInput): def __init__(self): 
super().__init__(spectrogram_loader) @@ -83,16 +83,16 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.FILES: AudioClassificationPathsDataSource(), - DefaultDataSources.FOLDERS: AudioClassificationPathsDataSource(), - "data_frame": AudioClassificationDataFrameDataSource(), - DefaultDataSources.CSV: AudioClassificationDataFrameDataSource(), - DefaultDataSources.NUMPY: AudioClassificationNumpyDataSource(), - DefaultDataSources.TENSORS: AudioClassificationTensorDataSource(), + inputs={ + InputFormat.FILES: AudioClassificationPathsInput(), + InputFormat.FOLDERS: AudioClassificationPathsInput(), + "data_frame": AudioClassificationDataFrameInput(), + InputFormat.CSV: AudioClassificationDataFrameInput(), + InputFormat.NUMPY: AudioClassificationNumpyInput(), + InputFormat.TENSORS: AudioClassificationTensorInput(), }, deserializer=deserializer or ImageDeserializer(), - default_data_source=DefaultDataSources.FILES, + default_input=InputFormat.FILES, ) def get_state_dict(self) -> Dict[str, Any]: diff --git a/flash/audio/classification/transforms.py b/flash/audio/classification/transforms.py index edf77da85b..a61a41012c 100644 --- a/flash/audio/classification/transforms.py +++ b/flash/audio/classification/transforms.py @@ -17,7 +17,7 @@ from torch import nn from torch.utils.data._utils.collate import default_collate -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.transforms import ApplyToKeys, merge_transforms from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE @@ -34,10 +34,10 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable] spectrogram and target to a tensor, and collate the batch.""" return { "to_tensor_transform": nn.Sequential( - ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), - 
ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ), - "post_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)), + "post_tensor_transform": ApplyToKeys(DataKeys.INPUT, T.Resize(spectrogram_size)), "collate": default_collate, } @@ -49,10 +49,10 @@ def train_default_transforms( augs = [] if time_mask_param is not None: - augs.append(ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param))) + augs.append(ApplyToKeys(DataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param))) if freq_mask_param is not None: - augs.append(ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param))) + augs.append(ApplyToKeys(DataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param))) if len(augs) > 0: return merge_transforms(default_transforms(spectrogram_size), {"post_tensor_transform": nn.Sequential(*augs)}) diff --git a/flash/audio/speech_recognition/collate.py b/flash/audio/speech_recognition/collate.py index 9ee53a4686..cbb907726b 100644 --- a/flash/audio/speech_recognition/collate.py +++ b/flash/audio/speech_recognition/collate.py @@ -16,7 +16,7 @@ import torch -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _AUDIO_AVAILABLE if _AUDIO_AVAILABLE: @@ -60,7 +60,7 @@ class DataCollatorCTCWithPadding: pad_to_multiple_of_labels: Optional[int] = None def __call__(self, samples: List[Dict[str, Any]], metadata: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: - inputs = [sample[DefaultDataKeys.INPUT] for sample in samples] + inputs = [sample[DataKeys.INPUT] for sample in samples] sampling_rates = [sample["sampling_rate"] for sample in metadata] assert ( @@ -81,7 +81,7 @@ def __call__(self, samples: List[Dict[str, Any]], metadata: List[Dict[str, Any]] 
return_tensors="pt", ) - labels = [sample.get(DefaultDataKeys.TARGET, None) for sample in samples] + labels = [sample.get(DataKeys.TARGET, None) for sample in samples] # check to ensure labels exist to collate if None not in labels: with self.processor.as_target_processor(): diff --git a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py index 6f1fbb48bb..2383e1146c 100644 --- a/flash/audio/speech_recognition/data.py +++ b/flash/audio/speech_recognition/data.py @@ -23,13 +23,7 @@ import flash from flash.core.data.data_module import DataModule -from flash.core.data.data_source import ( - DatasetDataSource, - DataSource, - DefaultDataKeys, - DefaultDataSources, - PathsDataSource, -) +from flash.core.data.io.input import DataKeys, DatasetInput, Input, InputFormat, PathsInput from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform from flash.core.data.process import Deserializer @@ -57,8 +51,8 @@ def deserialize(self, sample: Any) -> Dict: buffer = io.BytesIO(audio) data, sampling_rate = librosa.load(buffer, sr=self.sampling_rate) return { - DefaultDataKeys.INPUT: data, - DefaultDataKeys.METADATA: {"sampling_rate": sampling_rate}, + DataKeys.INPUT: data, + DataKeys.METADATA: {"sampling_rate": sampling_rate}, } @property @@ -70,20 +64,16 @@ def example_input(self) -> str: class BaseSpeechRecognition: @staticmethod def _load_sample(sample: Dict[str, Any], sampling_rate: int) -> Any: - path = sample[DefaultDataKeys.INPUT] - if ( - not os.path.isabs(path) - and DefaultDataKeys.METADATA in sample - and "root" in sample[DefaultDataKeys.METADATA] - ): - path = os.path.join(sample[DefaultDataKeys.METADATA]["root"], path) + path = sample[DataKeys.INPUT] + if not os.path.isabs(path) and DataKeys.METADATA in sample and "root" in sample[DataKeys.METADATA]: + path = os.path.join(sample[DataKeys.METADATA]["root"], path) speech_array, sampling_rate = librosa.load(path, sr=sampling_rate) 
- sample[DefaultDataKeys.INPUT] = speech_array - sample[DefaultDataKeys.METADATA] = {"sampling_rate": sampling_rate} + sample[DataKeys.INPUT] = speech_array + sample[DataKeys.METADATA] = {"sampling_rate": sampling_rate} return sample -class SpeechRecognitionFileDataSource(DataSource, BaseSpeechRecognition): +class SpeechRecognitionFileInput(Input, BaseSpeechRecognition): def __init__(self, sampling_rate: int, filetype: Optional[str] = None): super().__init__() self.filetype = filetype @@ -108,9 +98,9 @@ def load_data( meta = {"root": os.path.dirname(file)} return [ { - DefaultDataKeys.INPUT: input_file, - DefaultDataKeys.TARGET: target, - DefaultDataKeys.METADATA: meta, + DataKeys.INPUT: input_file, + DataKeys.TARGET: target, + DataKeys.METADATA: meta, } for input_file, target in zip(dataset[input_key], dataset[target_key]) ] @@ -119,17 +109,17 @@ def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any: return self._load_sample(sample, self.sampling_rate) -class SpeechRecognitionCSVDataSource(SpeechRecognitionFileDataSource): +class SpeechRecognitionCSVInput(SpeechRecognitionFileInput): def __init__(self, sampling_rate: int): super().__init__(sampling_rate, filetype="csv") -class SpeechRecognitionJSONDataSource(SpeechRecognitionFileDataSource): +class SpeechRecognitionJSONInput(SpeechRecognitionFileInput): def __init__(self, sampling_rate: int): super().__init__(sampling_rate, filetype="json") -class SpeechRecognitionDatasetDataSource(DatasetDataSource, BaseSpeechRecognition): +class SpeechRecognitionDatasetInput(DatasetInput, BaseSpeechRecognition): def __init__(self, sampling_rate: int): super().__init__() @@ -141,12 +131,12 @@ def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Sequence[Ma return super().load_data(data, dataset) def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any: - if isinstance(sample[DefaultDataKeys.INPUT], (str, Path)): + if isinstance(sample[DataKeys.INPUT], (str, Path)): sample = 
self._load_sample(sample, self.sampling_rate) return sample -class SpeechRecognitionPathsDataSource(PathsDataSource, BaseSpeechRecognition): +class SpeechRecognitionPathsInput(PathsInput, BaseSpeechRecognition): def __init__(self, sampling_rate: int): super().__init__(("wav", "ogg", "flac", "mat", "mp3")) @@ -171,13 +161,13 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.CSV: SpeechRecognitionCSVDataSource(sampling_rate), - DefaultDataSources.JSON: SpeechRecognitionJSONDataSource(sampling_rate), - DefaultDataSources.FILES: SpeechRecognitionPathsDataSource(sampling_rate), - DefaultDataSources.DATASETS: SpeechRecognitionDatasetDataSource(sampling_rate), + inputs={ + InputFormat.CSV: SpeechRecognitionCSVInput(sampling_rate), + InputFormat.JSON: SpeechRecognitionJSONInput(sampling_rate), + InputFormat.FILES: SpeechRecognitionPathsInput(sampling_rate), + InputFormat.DATASETS: SpeechRecognitionDatasetInput(sampling_rate), }, - default_data_source=DefaultDataSources.FILES, + default_input=InputFormat.FILES, deserializer=SpeechRecognitionDeserializer(sampling_rate), ) diff --git a/flash/core/classification.py b/flash/core/classification.py index 90e837b389..39f0bd0c80 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -19,7 +19,7 @@ from pytorch_lightning.utilities import rank_zero_warn from flash.core.adapter import AdapterTask -from flash.core.data.data_source import DefaultDataKeys, LabelsState +from flash.core.data.io.input import DataKeys, LabelsState from flash.core.data.io.output import Output from flash.core.model import Task from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires @@ -125,12 +125,12 @@ def multi_label(self) -> bool: class PredsClassificationOutput(ClassificationOutput): """A :class:`~flash.core.classification.ClassificationOutput` which gets the - 
- :attr:`~flash.core.data.data_source.DefaultDataKeys.PREDS` from the sample. + :attr:`~flash.core.data.io.input.DataKeys.PREDS` from the sample. """ def transform(self, sample: Any) -> Any: - if isinstance(sample, Mapping) and DefaultDataKeys.PREDS in sample: - sample = sample[DefaultDataKeys.PREDS] + if isinstance(sample, Mapping) and DataKeys.PREDS in sample: + sample = sample[DataKeys.PREDS] if not isinstance(sample, torch.Tensor): sample = torch.tensor(sample) return sample @@ -258,7 +258,7 @@ def transform( self, sample: Any, ) -> Union[Classification, Classifications, Dict[str, Any]]: - pred = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample + pred = sample[DataKeys.PREDS] if isinstance(sample, Dict) else sample pred = torch.tensor(pred) labels = None @@ -335,6 +335,6 @@ def transform( ) if self.return_filepath: - filepath = sample[DefaultDataKeys.METADATA]["filepath"] + filepath = sample[DataKeys.METADATA]["filepath"] return {"filepath": filepath, "predictions": fo_predictions} return fo_predictions diff --git a/flash/core/data/auto_dataset.py b/flash/core/data/auto_dataset.py index f301cb2e67..958671dad3 100644 --- a/flash/core/data/auto_dataset.py +++ b/flash/core/data/auto_dataset.py @@ -24,15 +24,15 @@ class BaseAutoDataset(Generic[DATA_TYPE]): - """The ``BaseAutoDataset`` class wraps the output of a call to :meth:`~flash.core.data.data_source.DataSource.load_data` - and a :class:`~fash.data.data_source.DataSource` and provides the ``_call_load_sample`` method to call - :meth:`~flash.core.data.data_source.DataSource.load_sample` with the correct + """The ``BaseAutoDataset`` class wraps the output of a call to :meth:`~flash.core.data.io.input.Input.load_data` + and a :class:`~flash.core.data.io.input.Input` and provides the ``_call_load_sample`` method to call + :meth:`~flash.core.data.io.input.Input.load_sample` with the correct :class:`~flash.core.data.utils.CurrentRunningStageFuncContext` for the current ``running_stage``. 
Inheriting classes are responsible for extracting samples from ``data`` to be given to ``_call_load_sample``. Args: - data: The output of a call to :meth:`~flash.core.data.data_source.DataSource.load_data`. - data_source: The :class:`~flash.core.data.data_source.DataSource` which has the ``load_sample`` method. + data: The output of a call to :meth:`~flash.core.data.io.input.Input.load_data`. + input: The :class:`~flash.core.data.io.input.Input` which has the ``load_sample`` method. running_stage: The current running stage. """ @@ -41,13 +41,13 @@ class BaseAutoDataset(Generic[DATA_TYPE]): def __init__( self, data: DATA_TYPE, - data_source: "flash.core.data.data_source.DataSource", + input: "flash.core.data.io.input.Input", running_stage: RunningStage, ) -> None: super().__init__() self.data = data - self.data_source = data_source + self.input = input self._running_stage = None self.running_stage = running_stage @@ -59,19 +59,19 @@ def running_stage(self) -> RunningStage: @running_stage.setter def running_stage(self, running_stage: RunningStage) -> None: from flash.core.data.data_pipeline import DataPipeline # noqa F811 - from flash.core.data.data_source import DataSource # noqa F811 # TODO: something better than this + from flash.core.data.io.input import Input # noqa F811 # TODO: something better than this self._running_stage = running_stage - self._load_sample_context = CurrentRunningStageFuncContext(self.running_stage, "load_sample", self.data_source) + self._load_sample_context = CurrentRunningStageFuncContext(self.running_stage, "load_sample", self.input) self.load_sample: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr( - self.data_source, + self.input, DataPipeline._resolve_function_hierarchy( "load_sample", - self.data_source, + self.input, self.running_stage, - DataSource, + Input, ), ) diff --git a/flash/core/data/callback.py b/flash/core/data/callback.py index 3669914639..64bd06ad39 100644 --- a/flash/core/data/callback.py +++ 
b/flash/core/data/callback.py @@ -89,14 +89,14 @@ class BaseDataFetcher(FlashCallback): from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule - from flash.core.data.data_source import DataSource + from flash.core.data.io.input import Input from flash.core.data.io.input_transform import InputTransform class CustomInputTransform(InputTransform): def __init__(**kwargs): super().__init__( - data_sources = {"inputs": DataSource()}, + inputs = {"inputs": Input()}, **kwargs, ) @@ -121,7 +121,7 @@ def from_inputs( test_data: Any, predict_data: Any, ) -> "CustomDataModule": - return cls.from_data_source( + return cls.from_input( "inputs", train_data=train_data, val_data=val_data, diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 8ece38d667..5c7d062fcd 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -40,7 +40,7 @@ from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_pipeline import DataPipeline -from flash.core.data.data_source import DataSource, DefaultDataSources +from flash.core.data.io.input import Input, InputFormat from flash.core.data.io.input_transform import DefaultInputTransform, InputTransform from flash.core.data.io.output_transform import OutputTransform from flash.core.data.splits import SplitDataset @@ -56,7 +56,7 @@ class DataModule(pl.LightningDataModule): """A basic DataModule class for all Flash tasks. This class includes references to a - :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.io.input_transform.InputTransform`, + :class:`~flash.core.data.io.input.Input`, :class:`~flash.core.data.io.input_transform.InputTransform`, :class:`~flash.core.data.io.output_transform.OutputTransform`, and a :class:`~flash.core.data.callback.BaseDataFetcher`. 
@@ -65,7 +65,7 @@ class DataModule(pl.LightningDataModule): val_dataset: Dataset for validating model performance during training. Defaults to None. test_dataset: Dataset to test model performance. Defaults to None. predict_dataset: Dataset for predicting. Defaults to None. - data_source: The :class:`~flash.core.data.data_source.DataSource` that was used to create the datasets. + input: The :class:`~flash.core.data.io.input.Input` that was used to create the datasets. input_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` to use when constructing the :class:`~flash.core.data.data_pipeline.DataPipeline`. If ``None``, a :class:`~flash.core.data.io.input_transform.DefaultInputTransform` will be used. @@ -94,7 +94,7 @@ def __init__( val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, predict_dataset: Optional[Dataset] = None, - data_source: Optional[DataSource] = None, + input: Optional[Input] = None, input_transform: Optional[InputTransform] = None, output_transform: Optional[OutputTransform] = None, data_fetcher: Optional[BaseDataFetcher] = None, @@ -109,7 +109,7 @@ def __init__( if flash._IS_TESTING and torch.cuda.is_available(): batch_size = 16 - self._data_source: DataSource = data_source + self._input: Input = input self._input_tranform: Optional[InputTransform] = input_transform self._output_transform: Optional[OutputTransform] = output_transform self._viz: Optional[BaseVisualization] = None @@ -426,9 +426,9 @@ def multi_label(self) -> Optional[bool]: return multi_label_train or multi_label_val or multi_label_test @property - def data_source(self) -> Optional[DataSource]: + def input(self) -> Optional[Input]: """Property that returns the data source.""" - return self._data_source + return self._input @property def input_transform(self) -> InputTransform: @@ -445,16 +445,16 @@ def output_transform(self) -> OutputTransform: def data_pipeline(self) -> DataPipeline: """Property that returns the full data pipeline 
including the data source, input transform and postprocessing.""" - return DataPipeline(self.data_source, self.input_transform, self.output_transform) + return DataPipeline(self.input, self.input_transform, self.output_transform) - def available_data_sources(self) -> Sequence[str]: + def available_inputs(self) -> Sequence[str]: """Get the list of available data source names for use with this :class:`~flash.core.data.data_module.DataModule`. Returns: The list of data source names. """ - return self.input_transform.available_data_sources() + return self.input_transform.available_inputs() @staticmethod def _split_train_val( @@ -492,9 +492,9 @@ def _split_train_val( ) @classmethod - def from_data_source( + def from_input( cls, - data_source: str, + input: str, train_data: Any = None, val_data: Any = None, test_data: Any = None, @@ -512,21 +512,21 @@ def from_data_source( **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given inputs to - :meth:`~flash.core.data.data_source.DataSource.load_data` (``train_data``, ``val_data``, ``test_data``, + :meth:`~flash.core.data.io.input.Input.load_data` (``train_data``, ``val_data``, ``test_data``, ``predict_data``). The data source will be resolved from the instantiated :class:`~flash.core.data.io.input_transform.InputTransform` - using :meth:`~flash.core.data.io.input_transform.InputTransform.data_source_of_name`. + using :meth:`~flash.core.data.io.input_transform.InputTransform.input_of_name`. Args: - data_source: The name of the data source to use for the - :meth:`~flash.core.data.data_source.DataSource.load_data`. - train_data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use when creating + input: The name of the data source to use for the + :meth:`~flash.core.data.io.input.Input.load_data`. + train_data: The input to :meth:`~flash.core.data.io.input.Input.load_data` to use when creating the train dataset. 
- val_data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use when creating + val_data: The input to :meth:`~flash.core.data.io.input.Input.load_data` to use when creating the validation dataset. - test_data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use when creating + test_data: The input to :meth:`~flash.core.data.io.input.Input.load_data` to use when creating the test dataset. - predict_data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use when creating + predict_data: The input to :meth:`~flash.core.data.io.input.Input.load_data` to use when creating the predict dataset. train_transform: The dictionary of transforms to use during training which maps :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. @@ -553,8 +553,8 @@ def from_data_source( Examples:: - data_module = DataModule.from_data_source( - DefaultDataSources.FOLDERS, + data_module = DataModule.from_input( + InputFormat.FOLDERS, train_data="train_folder", train_transform={ "to_tensor_transform": torch.as_tensor, @@ -570,9 +570,9 @@ def from_data_source( **input_transform_kwargs, ) - data_source = input_transform.data_source_of_name(data_source) + input = input_transform.input_of_name(input) - train_dataset, val_dataset, test_dataset, predict_dataset = data_source.to_datasets( + train_dataset, val_dataset, test_dataset, predict_dataset = input.to_datasets( train_data, val_data, test_data, @@ -584,7 +584,7 @@ def from_data_source( val_dataset, test_dataset, predict_dataset, - data_source=data_source, + input=input, input_transform=input_transform, data_fetcher=data_fetcher, val_split=val_split, @@ -613,8 +613,8 @@ def from_folders( **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the - :class:`~flash.core.data.data_source.DataSource` of name - 
:attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS` + :class:`~flash.core.data.io.input.Input` of name + :attr:`~flash.core.data.io.input.InputFormat.FOLDERS` from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: @@ -645,8 +645,8 @@ def from_folders( Returns: The constructed data module. """ - return cls.from_data_source( - DefaultDataSources.FOLDERS, + return cls.from_input( + InputFormat.FOLDERS, train_folder, val_folder, test_folder, @@ -687,8 +687,8 @@ def from_files( **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given sequences of files - using the :class:`~flash.core.data.data_source.DataSource` of name - :attr:`~flash.core.data.data_source.DefaultDataSources.FILES` from the passed or constructed + using the :class:`~flash.core.data.io.input.Input` of name + :attr:`~flash.core.data.io.input.InputFormat.FILES` from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: @@ -722,8 +722,8 @@ def from_files( Returns: The constructed data module. """ - return cls.from_data_source( - DefaultDataSources.FILES, + return cls.from_input( + InputFormat.FILES, (train_files, train_targets), (val_files, val_targets), (test_files, test_targets), @@ -764,8 +764,8 @@ def from_tensors( **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given tensors using the - :class:`~flash.core.data.data_source.DataSource` - of name :attr:`~flash.core.data.data_source.DefaultDataSources.TENSOR` + :class:`~flash.core.data.io.input.Input` + of name :attr:`~flash.core.data.io.input.InputFormat.TENSORS` from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`.
Args: @@ -809,8 +809,8 @@ def from_tensors( }, ) """ - return cls.from_data_source( - DefaultDataSources.TENSORS, + return cls.from_input( + InputFormat.TENSORS, (train_data, train_targets), (val_data, val_targets), (test_data, test_targets), @@ -851,8 +851,8 @@ def from_numpy( **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given numpy array using the - :class:`~flash.core.data.data_source.DataSource` - of name :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` + :class:`~flash.core.data.io.input.Input` + of name :attr:`~flash.core.data.io.input.InputFormat.NUMPY` from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: @@ -896,8 +896,8 @@ def from_numpy( }, ) """ - return cls.from_data_source( - DefaultDataSources.NUMPY, + return cls.from_input( + InputFormat.NUMPY, (train_data, train_targets), (val_data, val_targets), (test_data, test_targets), @@ -938,8 +938,8 @@ def from_json( **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given JSON files using the - :class:`~flash.core.data.data_source.DataSource` - of name :attr:`~flash.core.data.data_source.DefaultDataSources.JSON` + :class:`~flash.core.data.io.input.Input` + of name :attr:`~flash.core.data.io.input.InputFormat.JSON` from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. 
Args: @@ -1006,8 +1006,8 @@ def from_json( feild="data" ) """ - return cls.from_data_source( - DefaultDataSources.JSON, + return cls.from_input( + InputFormat.JSON, (train_file, input_fields, target_fields, field), (val_file, input_fields, target_fields, field), (test_file, input_fields, target_fields, field), @@ -1047,8 +1047,8 @@ def from_csv( **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given CSV files using the - :class:`~flash.core.data.data_source.DataSource` - of name :attr:`~flash.core.data.data_source.DefaultDataSources.CSV` + :class:`~flash.core.data.io.input.Input` + of name :attr:`~flash.core.data.io.input.InputFormat.CSV` from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: @@ -1092,8 +1092,8 @@ def from_csv( }, ) """ - return cls.from_data_source( - DefaultDataSources.CSV, + return cls.from_input( + InputFormat.CSV, (train_file, input_fields, target_fields), (val_file, input_fields, target_fields), (test_file, input_fields, target_fields), @@ -1131,8 +1131,8 @@ def from_datasets( **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given datasets using the - :class:`~flash.core.data.data_source.DataSource` - of name :attr:`~flash.core.data.data_source.DefaultDataSources.DATASETS` + :class:`~flash.core.data.io.input.Input` + of name :attr:`~flash.core.data.io.input.InputFormat.DATASETS` from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. 
Args: @@ -1172,8 +1172,8 @@ def from_datasets( }, ) """ - return cls.from_data_source( - DefaultDataSources.DATASETS, + return cls.from_input( + InputFormat.DATASETS, train_dataset, val_dataset, test_dataset, @@ -1212,8 +1212,8 @@ def from_fiftyone( ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given FiftyOne Datasets using the - :class:`~flash.core.data.data_source.DataSource` of name - :attr:`~flash.core.data.data_source.DefaultDataSources.FIFTYONE` + :class:`~flash.core.data.io.input.Input` of name + :attr:`~flash.core.data.io.input.InputFormat.FIFTYONE` from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: @@ -1256,8 +1256,8 @@ def from_fiftyone( }, ) """ - return cls.from_data_source( - DefaultDataSources.FIFTYONE, + return cls.from_input( + InputFormat.FIFTYONE, train_dataset, val_dataset, test_dataset, @@ -1300,8 +1300,8 @@ def from_labelstudio( ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given export file and data directory using the - :class:`~flash.core.data.data_source.DataSource` of name - :attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS` + :class:`~flash.core.data.io.input.Input` of name + :attr:`~flash.core.data.io.input.InputFormat.LABELSTUDIO` from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`.
Args: @@ -1381,8 +1381,8 @@ def from_labelstudio( "export_json": predict_export_json, "multi_label": input_transform_kwargs.get("multi_label", False), } - return cls.from_data_source( - DefaultDataSources.LABELSTUDIO, + return cls.from_input( + InputFormat.LABELSTUDIO, train_data=train_data if train_data else data, val_data=val_data, test_data=test_data, diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index b671f1347a..48f5a2f8c7 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -25,7 +25,7 @@ import flash from flash.core.data.auto_dataset import IterableAutoDataset from flash.core.data.batch import _DeserializeProcessor -from flash.core.data.data_source import DataSource +from flash.core.data.io.input import Input from flash.core.data.io.input_transform import ( _InputTransformProcessor, _InputTransformSequential, @@ -93,13 +93,13 @@ class DataPipeline: def __init__( self, - data_source: Optional[DataSource] = None, + input: Optional[Input] = None, input_transform: Optional[InputTransform] = None, output_transform: Optional[OutputTransform] = None, deserializer: Optional[Deserializer] = None, output: Optional[Output] = None, ) -> None: - self.data_source = data_source + self.input = input self._input_transform_pipeline = input_transform or DefaultInputTransform() self._output_transform = output_transform or OutputTransform() @@ -112,8 +112,8 @@ def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) -> :class:`.OutputTransform`, and :class:`.Output`. 
Once this has been called, any attempt to add new state will give a warning.""" data_pipeline_state = data_pipeline_state or DataPipelineState() - if self.data_source is not None: - self.data_source.attach_data_pipeline_state(data_pipeline_state) + if self.input is not None: + self.input.attach_data_pipeline_state(data_pipeline_state) self._input_transform_pipeline.attach_data_pipeline_state(data_pipeline_state) self._output_transform.attach_data_pipeline_state(data_pipeline_state) self._output.attach_data_pipeline_state(data_pipeline_state) @@ -570,14 +570,14 @@ def _detach_output_transform_from_model(model: "Task"): model.predict_step = model.predict_step._original def __str__(self) -> str: - data_source: DataSource = self.data_source + input: Input = self.input input_transform: InputTransform = self._input_transform_pipeline output_transform: OutputTransform = self._output_transform output: Output = self._output deserializer: Deserializer = self._deserializer return ( f"{self.__class__.__name__}(" - f"data_source={str(data_source)}, " + f"input={str(input)}, " f"deserializer={deserializer}, " f"input_transform={input_transform}, " f"output_transform={output_transform}, " diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py index bd875e81ac..a767936d9e 100644 --- a/flash/core/data/data_source.py +++ b/flash/core/data/data_source.py @@ -11,706 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os -import typing import warnings -from dataclasses import dataclass -from functools import partial -from inspect import signature -from pathlib import Path -from typing import ( - Any, - Callable, - cast, - Dict, - Generic, - Iterable, - Iterator, - List, - Mapping, - Optional, - Sequence, - Tuple, - TYPE_CHECKING, - TypeVar, - Union, -) -import numpy as np -import pandas as pd -import torch -from pytorch_lightning.utilities.enums import LightningEnum -from torch.nn import Module -from torch.utils.data.dataset import Dataset -from tqdm import tqdm +from pytorch_lightning.utilities import LightningEnum -from flash.core.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset -from flash.core.data.properties import ProcessState, Properties -from flash.core.data.utils import CurrentRunningStageFuncContext -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires -from flash.core.utilities.stages import RunningStage +from flash.core.utilities.on_access_enum_meta import OnAccessEnumMeta -SampleCollection = None -if _FIFTYONE_AVAILABLE: - fol = lazy_import("fiftyone.core.labels") - if TYPE_CHECKING: - from fiftyone.core.collections import SampleCollection -else: - fol = None +class DefaultDataKeys(LightningEnum, metaclass=OnAccessEnumMeta): + """Deprecated since 0.6.0 and will be removed in 0.7.0. -# Credit to the PyTorchVision Team: -# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py#L10 -def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: - """Checks if a file is an allowed extension. 
- - Args: - filename (string): path to a file - extensions (tuple of strings): extensions to consider (lowercase) - - Returns: - bool: True if the filename ends with one of given extensions - """ - return filename.lower().endswith(extensions) - - -# Credit to the PyTorchVision Team: -# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py#L48 -def make_dataset( - directory: str, - class_to_idx: Dict[str, int], - extensions: Optional[Tuple[str, ...]] = None, - is_valid_file: Optional[Callable[[str], bool]] = None, -) -> List[Tuple[str, int]]: - """Generates a list of samples of a form (path_to_sample, class). - - Args: - directory (str): root dataset directory - class_to_idx (Dict[str, int]): dictionary mapping class name to class index - extensions (optional): A list of allowed extensions. - Either extensions or is_valid_file should be passed. Defaults to None. - is_valid_file (optional): A function that takes path of a file - and checks if the file is a valid file - (used to check of corrupt files) both extensions and - is_valid_file should not be passed. Defaults to None. - - Raises: - ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. - - Returns: - List[Tuple[str, int]]: samples of a form (path_to_sample, class) + Use `flash.DataKeys` instead. 
""" - instances = [] - directory = os.path.expanduser(directory) - both_none = extensions is None and is_valid_file is None - both_something = extensions is not None and is_valid_file is not None - if both_none or both_something: - raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") - if extensions is not None: - - def is_valid_file(x: str) -> bool: - return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) - - is_valid_file = cast(Callable[[str], bool], is_valid_file) - for target_class in sorted(class_to_idx.keys()): - class_index = class_to_idx[target_class] - target_dir = os.path.join(directory, target_class) - if not os.path.isdir(target_dir): - continue - for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): - for fname in sorted(fnames): - path = os.path.join(root, fname) - if is_valid_file(path): - item = path, class_index - instances.append(item) - return instances - - -def has_len(data: Union[Sequence[Any], Iterable[Any]]) -> bool: - try: - len(data) - return True - except (TypeError, NotImplementedError): - return False - - -@dataclass(unsafe_hash=True, frozen=True) -class LabelsState(ProcessState): - """A :class:`~flash.core.data.properties.ProcessState` containing ``labels``, a mapping from class index to - label.""" - - labels: Optional[Sequence[str]] - - -@dataclass(unsafe_hash=True, frozen=True) -class ImageLabelsMap(ProcessState): - - labels_map: Optional[Dict[int, Tuple[int, int, int]]] - - -class DefaultDataSources(LightningEnum): - """The ``DefaultDataSources`` enum contains the data source names used by all of the default ``from_*`` methods in - :class:`~flash.core.data.data_module.DataModule`.""" - - FOLDERS = "folders" - FILES = "files" - NUMPY = "numpy" - TENSORS = "tensors" - CSV = "csv" - JSON = "json" - PARQUET = "parquet" - DATASETS = "datasets" - HUGGINGFACE_DATASET = "hf_dataset" - FIFTYONE = "fiftyone" - DATAFRAME = "data_frame" - LISTS = "lists" - 
LABELSTUDIO = "labelstudio" - - # TODO: Create a FlashEnum class??? - def __hash__(self) -> int: - return hash(self.value) - - -class DefaultDataKeys(LightningEnum): - """The ``DefaultDataKeys`` enum contains the keys that are used by built-in data sources to refer to inputs and - targets.""" INPUT = "input" PREDS = "preds" TARGET = "target" METADATA = "metadata" - # TODO: Create a FlashEnum class??? - def __hash__(self) -> int: - return hash(self.value) + def __new__(cls, value): + member = str.__new__(cls, value) + member._on_access = member.deprecate + return member + def deprecate(self): + warnings.warn( + "`DefaultDataKeys` was deprecated in 0.6.0 and will be removed in 0.7.0. Use `flash.DataKeys` instead.", + FutureWarning, + ) -class BaseDataFormat(LightningEnum): - """The base class for creating ``data_format`` for :class:`~flash.core.data.data_source.DataSource`.""" - + # TODO: Create a FlashEnum class??? def __hash__(self) -> int: return hash(self.value) - - -class MockDataset: - """The ``MockDataset`` catches any metadata that is attached through ``__setattr__``. - - This is passed to - :meth:`~flash.core.data.data_source.DataSource.load_data` so that attributes can be set on the generated - data set. - """ - - def __init__(self): - self.metadata = {} - - def __setattr__(self, key, value): - if key != "metadata": - self.metadata[key] = value - object.__setattr__(self, key, value) - - -DATA_TYPE = TypeVar("DATA_TYPE") - - -class DataSource(Generic[DATA_TYPE], Properties, Module): - """The ``DataSource`` class encapsulates two hooks: ``load_data`` and ``load_sample``. - - The - :meth:`~flash.core.data.data_source.DataSource.to_datasets` method can then be used to automatically construct data - sets from the hooks. 
- """ - - @staticmethod - def load_data( - data: DATA_TYPE, - dataset: Optional[Any] = None, - ) -> Union[Sequence[Mapping[str, Any]], Iterable[Mapping[str, Any]]]: - """Given the ``data`` argument, the ``load_data`` hook produces a sequence or iterable of samples or - sample metadata. The ``data`` argument can be anything, but this method should return a sequence or iterable of - mappings from string (e.g. "input", "target", "bbox", etc.) to data (e.g. a target value) or metadata (e.g. a - filename). Where possible, any heavy data loading should be performed in - :meth:`~flash.core.data.data_source.DataSource.load_sample`. If the output is an iterable rather than a sequence - (that is, it doesn't have length) then the generated dataset will be an ``IterableDataset``. - - Args: - data: The data required to load the sequence or iterable of samples or sample metadata. - dataset: Overriding methods can optionally include the dataset argument. Any attributes set on the dataset - (e.g. ``num_classes``) will also be set on the generated dataset. - - Returns: - A sequence or iterable of samples or sample metadata to be used as inputs to - :meth:`~flash.core.data.data_source.DataSource.load_sample`. - - Example:: - - # data: "." - # output: [{"input": "./cat/1.png", "target": 1}, ..., {"input": "./dog/10.png", "target": 0}] - - output: Sequence[Mapping[str, Any]] = load_data(data) - - """ - return data - - @staticmethod - def load_sample(sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: - """Given an element from the output of a call to - :meth:`~flash.core.data.data_source.DataSource.load_data`, this hook - should load a single data sample. The keys and values in the ``sample`` argument will be same as the keys and - values in the outputs of :meth:`~flash.core.data.data_source.DataSource.load_data`. - - Args: - sample: An element (sample or sample metadata) from the output of a call to - :meth:`~flash.core.data.data_source.DataSource.load_data`. 
- dataset: Overriding methods can optionally include the dataset argument. Any attributes set on the dataset - (e.g. ``num_classes``) will also be set on the generated dataset. - - Returns: - The loaded sample as a mapping with string keys (e.g. "input", "target") that can be processed by the - :meth:`~flash.core.data.io.input_transform.InputTransform.pre_tensor_transform`. - - Example:: - - # sample: {"input": "./cat/1.png", "target": 1} - # output: {"input": PIL.Image, "target": 1} - - output: Mapping[str, Any] = load_sample(sample) - - """ - return sample - - def to_datasets( - self, - train_data: Optional[DATA_TYPE] = None, - val_data: Optional[DATA_TYPE] = None, - test_data: Optional[DATA_TYPE] = None, - predict_data: Optional[DATA_TYPE] = None, - ) -> Tuple[Optional[BaseAutoDataset], ...]: - """Construct data sets (of type :class:`~flash.core.data.auto_dataset.BaseAutoDataset`) from this data - source by calling :meth:`~flash.core.data.data_source.DataSource.load_data` with each of the ``*_data`` - arguments. If an argument is given as ``None`` then no dataset will be created for that stage (``train``, - ``val``, ``test``, ``predict``). - - Args: - train_data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use to create the - train dataset. - val_data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use to create the - validation dataset. - test_data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use to create the - test dataset. - predict_data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use to create - the predict dataset. - - Returns: - A tuple of ``train_dataset``, ``val_dataset``, ``test_dataset``, ``predict_dataset``. If any ``*_data`` - argument is not passed to this method then the corresponding ``*_dataset`` will be ``None``. 
- """ - train_dataset = self.generate_dataset(train_data, RunningStage.TRAINING) - val_dataset = self.generate_dataset(val_data, RunningStage.VALIDATING) - test_dataset = self.generate_dataset(test_data, RunningStage.TESTING) - predict_dataset = self.generate_dataset(predict_data, RunningStage.PREDICTING) - return train_dataset, val_dataset, test_dataset, predict_dataset - - def generate_dataset( - self, - data: Optional[DATA_TYPE], - running_stage: RunningStage, - ) -> Optional[Union[AutoDataset, IterableAutoDataset]]: - """Generate a single dataset with the given input to - :meth:`~flash.core.data.data_source.DataSource.load_data` for the given ``running_stage``. - - Args: - data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use to create the dataset. - running_stage: The running_stage for this dataset. - - Returns: - The constructed :class:`~flash.core.data.auto_dataset.BaseAutoDataset`. - """ - is_none = data is None - - if isinstance(data, Sequence): - is_none = data[0] is None - - if not is_none: - from flash.core.data.data_pipeline import DataPipeline - - mock_dataset = typing.cast(AutoDataset, MockDataset()) - with CurrentRunningStageFuncContext(running_stage, "load_data", self): - resolved_func_name = DataPipeline._resolve_function_hierarchy( - "load_data", self, running_stage, DataSource - ) - load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(self, resolved_func_name) - parameters = signature(load_data).parameters - if len(parameters) > 1 and "dataset" in parameters: # TODO: This was DATASET_KEY before - data = load_data(data, mock_dataset) - else: - data = load_data(data) - - if has_len(data): - dataset = AutoDataset(data, self, running_stage) - else: - dataset = IterableAutoDataset(data, self, running_stage) - dataset.__dict__.update(mock_dataset.metadata) - return dataset - - -SEQUENCE_DATA_TYPE = TypeVar("SEQUENCE_DATA_TYPE") - - -class DatasetDataSource(DataSource[Dataset]): - """The ``DatasetDataSource`` 
implements default behaviours for data sources which expect the input to - :meth:`~flash.core.data.data_source.DataSource.load_data` to be a :class:`torch.utils.data.dataset.Dataset` - - Args: - labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the - :class:`~flash.core.data.data_source.LabelsState`. - """ - - def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]: - if isinstance(sample, tuple) and len(sample) == 2: - return {DefaultDataKeys.INPUT: sample[0], DefaultDataKeys.TARGET: sample[1]} - return {DefaultDataKeys.INPUT: sample} - - -class SequenceDataSource( - Generic[SEQUENCE_DATA_TYPE], - DataSource[Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]]], -): - """The ``SequenceDataSource`` implements default behaviours for data sources which expect the input to - :meth:`~flash.core.data.data_source.DataSource.load_data` to be a sequence of tuples (``(input, target)`` - where target can be ``None``). - - Args: - labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the - :class:`~flash.core.data.data_source.LabelsState`. 
- """ - - def __init__(self, labels: Optional[Sequence[str]] = None): - super().__init__() - - self.labels = labels - - if self.labels is not None: - self.set_state(LabelsState(self.labels)) - - def load_data( - self, - data: Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]], - dataset: Optional[Any] = None, - ) -> Sequence[Mapping[str, Any]]: - # TODO: Bring back the code to work out how many classes there are - inputs, targets = data - if targets is None: - return self.predict_load_data(data) - return [ - {DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in zip(inputs, targets) - ] - - @staticmethod - def predict_load_data(data: Sequence[SEQUENCE_DATA_TYPE]) -> Sequence[Mapping[str, Any]]: - return [{DefaultDataKeys.INPUT: input} for input in data] - - -class PathsDataSource(SequenceDataSource): - """The ``PathsDataSource`` implements default behaviours for data sources which expect the input to - :meth:`~flash.core.data.data_source.DataSource.load_data` to be either a directory with a subdirectory for - each class or a tuple containing list of files and corresponding list of targets. - - Args: - extensions: The file extensions supported by this data source (e.g. ``(".jpg", ".png")``). - labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the - :class:`~flash.core.data.data_source.LabelsState`. - """ - - def __init__( - self, - extensions: Optional[Tuple[str, ...]] = None, - loader: Optional[Callable[[str], Any]] = None, - labels: Optional[Sequence[str]] = None, - ): - super().__init__(labels=labels) - - self.extensions = extensions - self.loader = loader - - @staticmethod - def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: - """Finds the class folders in a dataset. Ensures that no class is a subdirectory of another. - - Args: - dir: Root directory path. 
- - Returns: - tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. - """ - classes = [d.name for d in os.scandir(dir) if d.is_dir()] - classes.sort() - class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} - return classes, class_to_idx - - @staticmethod - def isdir(data: Union[str, Tuple[List[str], List[Any]]]) -> bool: - try: - return os.path.isdir(data) - except TypeError: - # data is not path-like (e.g. it may be a list of paths) - return False - - def load_data( - self, data: Union[str, Tuple[List[str], List[Any]]], dataset: Optional[Any] = None - ) -> Sequence[Mapping[str, Any]]: - if self.isdir(data): - classes, class_to_idx = self.find_classes(data) - if not classes: - return self.predict_load_data(data) - self.set_state(LabelsState(classes)) - - if dataset is not None: - dataset.num_classes = len(classes) - - data = make_dataset(data, class_to_idx, extensions=self.extensions) - return [{DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in data] - elif dataset is not None: - dataset.num_classes = len(np.unique(data[1])) - - return list( - filter( - lambda sample: has_file_allowed_extension(sample[DefaultDataKeys.INPUT], self.extensions), - super().load_data(data, dataset), - ) - ) - - def predict_load_data( - self, data: Union[str, List[str]], dataset: Optional[Any] = None - ) -> Sequence[Mapping[str, Any]]: - if self.isdir(data): - data = [os.path.join(data, file) for file in os.listdir(data)] - - if not isinstance(data, list): - data = [data] - - data = [{DefaultDataKeys.INPUT: input} for input in data] - - return list( - filter( - lambda sample: has_file_allowed_extension(sample[DefaultDataKeys.INPUT], self.extensions), - data, - ) - ) - - def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - path = sample[DefaultDataKeys.INPUT] - - if self.loader is not None: - sample[DefaultDataKeys.INPUT] = self.loader(path) - - 
sample[DefaultDataKeys.METADATA] = { - "filepath": path, - } - return sample - - -class LoaderDataFrameDataSource( - DataSource[Tuple[pd.DataFrame, str, Union[str, List[str]], Optional[str], Optional[str]]] -): - def __init__(self, loader: Callable[[str], Any]): - super().__init__() - - self.loader = loader - - @staticmethod - def _walk_files(root: str) -> Iterator[str]: - for root, _, files in os.walk(root): - for file in files: - yield os.path.join(root, file) - - @staticmethod - def _default_resolver(root: str, id: str): - if os.path.isabs(id): - return id - - pattern = f"*{id}*" - - try: - return str(next(Path(root).rglob(pattern))) - except StopIteration: - raise ValueError( - f"Found no matches for pattern: {pattern} in directory: {root}. File IDs should uniquely identify the " - "file to load." - ) - - @staticmethod - def _resolve_file(resolver: Callable[[str, str], str], root: str, input_key: str, row: pd.Series) -> pd.Series: - row[input_key] = resolver(root, row[input_key]) - return row - - @staticmethod - def _resolve_target(label_to_class: Dict[str, int], target_key: str, row: pd.Series) -> pd.Series: - row[target_key] = label_to_class[row[target_key]] - return row - - @staticmethod - def _resolve_multi_target(target_keys: List[str], row: pd.Series) -> pd.Series: - row[target_keys[0]] = [row[target_key] for target_key in target_keys] - return row - - def load_data( - self, - data: Tuple[pd.DataFrame, str, Union[str, List[str]], Optional[str], Optional[str]], - dataset: Optional[Any] = None, - ) -> Sequence[Mapping[str, Any]]: - data, input_key, target_keys, root, resolver = data - - if isinstance(data, (str, Path)): - data = str(data) - data_frame = pd.read_csv(data) - if root is None: - root = os.path.dirname(data) - else: - data_frame = data - - if root is None: - root = "" - - if resolver is None: - warnings.warn("Using default resolver, this may take a while.", UserWarning) - resolver = self._default_resolver - - tqdm.pandas(desc="Resolving files") 
- data_frame = data_frame.progress_apply(partial(self._resolve_file, resolver, root, input_key), axis=1) - - if not self.predicting: - if isinstance(target_keys, List): - dataset.multi_label = True - dataset.num_classes = len(target_keys) - self.set_state(LabelsState(target_keys)) - data_frame = data_frame.apply(partial(self._resolve_multi_target, target_keys), axis=1) - target_keys = target_keys[0] - else: - dataset.multi_label = False - if self.training: - labels = list(sorted(data_frame[target_keys].unique())) - dataset.num_classes = len(labels) - self.set_state(LabelsState(labels)) - - labels = self.get_state(LabelsState) - - if labels is not None: - labels = labels.labels - label_to_class = {v: k for k, v in enumerate(labels)} - data_frame = data_frame.apply(partial(self._resolve_target, label_to_class, target_keys), axis=1) - - return [ - { - DefaultDataKeys.INPUT: row[input_key], - DefaultDataKeys.TARGET: row[target_keys], - } - for _, row in data_frame.iterrows() - ] - return [ - { - DefaultDataKeys.INPUT: row[input_key], - } - for _, row in data_frame.iterrows() - ] - - def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - # TODO: simplify this duplicated code from PathsDataSource - path = sample[DefaultDataKeys.INPUT] - - if self.loader is not None: - sample[DefaultDataKeys.INPUT] = self.loader(path) - - sample[DefaultDataKeys.METADATA] = { - "filepath": path, - } - return sample - - -class TensorDataSource(SequenceDataSource[torch.Tensor]): - """The ``TensorDataSource`` is a ``SequenceDataSource`` which expects the input to - :meth:`~flash.core.data.data_source.DataSource.load_data` to be a sequence of ``torch.Tensor`` objects.""" - - def load_data( - self, - data: Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]], - dataset: Optional[Any] = None, - ) -> Sequence[Mapping[str, Any]]: - # TODO: Bring back the code to work out how many classes there are - if len(data) == 2: - dataset.num_classes = 
len(torch.unique(torch.tensor(data[1]))) - return super().load_data(data, dataset) - - -class NumpyDataSource(SequenceDataSource[np.ndarray]): - """The ``NumpyDataSource`` is a ``SequenceDataSource`` which expects the input to - :meth:`~flash.core.data.data_source.DataSource.load_data` to be a sequence of ``np.ndarray`` objects.""" - - -class FiftyOneDataSource(DataSource[SampleCollection]): - """The ``FiftyOneDataSource`` expects the input to - :meth:`~flash.core.data.data_source.DataSource.load_data` to be a ``fiftyone.core.collections.SampleCollection``.""" - - def __init__(self, label_field: str = "ground_truth"): - super().__init__() - self.label_field = label_field - - @property - @requires("fiftyone") - def label_cls(self): - return fol.Label - - @requires("fiftyone") - def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: - self._validate(data) - - label_path = data._get_label_field_path(self.label_field, "label")[1] - - filepaths = data.values("filepath") - targets = data.values(label_path) - - classes = self._get_classes(data) - - if dataset is not None: - dataset.num_classes = len(classes) - - class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} - - if targets and isinstance(targets[0], list): - - def to_idx(t): - return [class_to_idx[x] for x in t] - - else: - - def to_idx(t): - return class_to_idx[t] - - return [ - { - DefaultDataKeys.INPUT: f, - DefaultDataKeys.TARGET: to_idx(t), - } - for f, t in zip(filepaths, targets) - ] - - @staticmethod - @requires("fiftyone") - def predict_load_data(data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: - return [{DefaultDataKeys.INPUT: f} for f in data.values("filepath")] - - def _validate(self, data): - label_type = data._get_label_field_type(self.label_field) - if not issubclass(label_type, self.label_cls): - raise ValueError(f"Expected field '{self.label_field}' to have type {self.label_cls}; found 
{label_type}") - - def _get_classes(self, data): - classes = data.classes.get(self.label_field, None) - - if not classes: - classes = data.default_classes - - if not classes: - label_path = data._get_label_field_path(self.label_field, "label")[1] - classes = data.distinct(label_path) - - return classes diff --git a/flash/core/data/input_transform.py b/flash/core/data/input_transform.py index 99604c355d..c6f4b63da5 100644 --- a/flash/core/data/input_transform.py +++ b/flash/core/data/input_transform.py @@ -20,7 +20,7 @@ from torch.utils.data._utils.collate import default_collate from flash.core.data.data_pipeline import DataPipeline -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.io.input_transform import _InputTransformProcessor from flash.core.data.properties import Properties from flash.core.data.states import CollateFn @@ -157,7 +157,7 @@ def collate(self, samples: Sequence, metadata=None) -> Any: # return collate_fn.collate_fn(samples) parameters = inspect.signature(collate_fn).parameters - if len(parameters) > 1 and DefaultDataKeys.METADATA in parameters: + if len(parameters) > 1 and DataKeys.METADATA in parameters: return collate_fn(samples, metadata) return collate_fn(samples) diff --git a/flash/core/data/io/input.py b/flash/core/data/io/input.py new file mode 100644 index 0000000000..a41274b45b --- /dev/null +++ b/flash/core/data/io/input.py @@ -0,0 +1,710 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import os +import typing +import warnings +from dataclasses import dataclass +from functools import partial +from inspect import signature +from pathlib import Path +from typing import ( + Any, + Callable, + cast, + Dict, + Generic, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) + +import numpy as np +import pandas as pd +import torch +from pytorch_lightning.utilities.enums import LightningEnum +from torch.nn import Module +from torch.utils.data.dataset import Dataset +from tqdm import tqdm + +from flash.core.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset +from flash.core.data.properties import ProcessState, Properties +from flash.core.data.utils import CurrentRunningStageFuncContext +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires +from flash.core.utilities.stages import RunningStage + +SampleCollection = None +if _FIFTYONE_AVAILABLE: + fol = lazy_import("fiftyone.core.labels") + if TYPE_CHECKING: + from fiftyone.core.collections import SampleCollection +else: + fol = None + + +# Credit to the PyTorchVision Team: +# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py#L10 +def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: + """Checks if a file is an allowed extension. 
+ + Args: + filename (string): path to a file + extensions (tuple of strings): extensions to consider (lowercase) + + Returns: + bool: True if the filename ends with one of given extensions + """ + return filename.lower().endswith(extensions) + + +# Credit to the PyTorchVision Team: +# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py#L48 +def make_dataset( + directory: str, + class_to_idx: Dict[str, int], + extensions: Optional[Tuple[str, ...]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, +) -> List[Tuple[str, int]]: + """Generates a list of samples of a form (path_to_sample, class). + + Args: + directory (str): root dataset directory + class_to_idx (Dict[str, int]): dictionary mapping class name to class index + extensions (optional): A list of allowed extensions. + Either extensions or is_valid_file should be passed. Defaults to None. + is_valid_file (optional): A function that takes path of a file + and checks if the file is a valid file + (used to check of corrupt files) both extensions and + is_valid_file should not be passed. Defaults to None. + + Raises: + ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. 
+ + Returns: + List[Tuple[str, int]]: samples of a form (path_to_sample, class) + """ + instances = [] + directory = os.path.expanduser(directory) + both_none = extensions is None and is_valid_file is None + both_something = extensions is not None and is_valid_file is not None + if both_none or both_something: + raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") + if extensions is not None: + + def is_valid_file(x: str) -> bool: + return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) + + is_valid_file = cast(Callable[[str], bool], is_valid_file) + for target_class in sorted(class_to_idx.keys()): + class_index = class_to_idx[target_class] + target_dir = os.path.join(directory, target_class) + if not os.path.isdir(target_dir): + continue + for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + item = path, class_index + instances.append(item) + return instances + + +def has_len(data: Union[Sequence[Any], Iterable[Any]]) -> bool: + try: + len(data) + return True + except (TypeError, NotImplementedError): + return False + + +@dataclass(unsafe_hash=True, frozen=True) +class LabelsState(ProcessState): + """A :class:`~flash.core.data.properties.ProcessState` containing ``labels``, a mapping from class index to + label.""" + + labels: Optional[Sequence[str]] + + +@dataclass(unsafe_hash=True, frozen=True) +class ImageLabelsMap(ProcessState): + + labels_map: Optional[Dict[int, Tuple[int, int, int]]] + + +class InputFormat(LightningEnum): + """The ``InputFormat`` enum contains the data source names used by all of the default ``from_*`` methods in + :class:`~flash.core.data.data_module.DataModule`.""" + + FOLDERS = "folders" + FILES = "files" + NUMPY = "numpy" + TENSORS = "tensors" + CSV = "csv" + JSON = "json" + PARQUET = "parquet" + DATASETS = "datasets" + HUGGINGFACE_DATASET = "hf_dataset" + FIFTYONE = 
"fiftyone" + DATAFRAME = "data_frame" + LISTS = "lists" + LABELSTUDIO = "labelstudio" + + # TODO: Create a FlashEnum class??? + def __hash__(self) -> int: + return hash(self.value) + + +class DataKeys(LightningEnum): + """The ``DataKeys`` enum contains the keys that are used by built-in data sources to refer to inputs and + targets.""" + + INPUT = "input" + PREDS = "preds" + TARGET = "target" + METADATA = "metadata" + + # TODO: Create a FlashEnum class??? + def __hash__(self) -> int: + return hash(self.value) + + +class BaseDataFormat(LightningEnum): + """The base class for creating ``data_format`` for :class:`~flash.core.data.io.input.Input`.""" + + def __hash__(self) -> int: + return hash(self.value) + + +class MockDataset: + """The ``MockDataset`` catches any metadata that is attached through ``__setattr__``. + + This is passed to + :meth:`~flash.core.data.io.input.Input.load_data` so that attributes can be set on the generated + data set. + """ + + def __init__(self): + self.metadata = {} + + def __setattr__(self, key, value): + if key != "metadata": + self.metadata[key] = value + object.__setattr__(self, key, value) + + +DATA_TYPE = TypeVar("DATA_TYPE") + + +class Input(Generic[DATA_TYPE], Properties, Module): + """The ``Input`` class encapsulates two hooks: ``load_data`` and ``load_sample``. + + The + :meth:`~flash.core.data.io.input.Input.to_datasets` method can then be used to automatically construct data + sets from the hooks. + """ + + @staticmethod + def load_data( + data: DATA_TYPE, + dataset: Optional[Any] = None, + ) -> Union[Sequence[Mapping[str, Any]], Iterable[Mapping[str, Any]]]: + """Given the ``data`` argument, the ``load_data`` hook produces a sequence or iterable of samples or + sample metadata. The ``data`` argument can be anything, but this method should return a sequence or iterable of + mappings from string (e.g. "input", "target", "bbox", etc.) to data (e.g. a target value) or metadata (e.g. a + filename). 
Where possible, any heavy data loading should be performed in + :meth:`~flash.core.data.io.input.Input.load_sample`. If the output is an iterable rather than a sequence + (that is, it doesn't have length) then the generated dataset will be an ``IterableDataset``. + + Args: + data: The data required to load the sequence or iterable of samples or sample metadata. + dataset: Overriding methods can optionally include the dataset argument. Any attributes set on the dataset + (e.g. ``num_classes``) will also be set on the generated dataset. + + Returns: + A sequence or iterable of samples or sample metadata to be used as inputs to + :meth:`~flash.core.data.io.input.Input.load_sample`. + + Example:: + + # data: "." + # output: [{"input": "./cat/1.png", "target": 1}, ..., {"input": "./dog/10.png", "target": 0}] + + output: Sequence[Mapping[str, Any]] = load_data(data) + + """ + return data + + @staticmethod + def load_sample(sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: + """Given an element from the output of a call to + :meth:`~flash.core.data.io.input.Input.load_data`, this hook + should load a single data sample. The keys and values in the ``sample`` argument will be same as the keys and + values in the outputs of :meth:`~flash.core.data.io.input.Input.load_data`. + + Args: + sample: An element (sample or sample metadata) from the output of a call to + :meth:`~flash.core.data.io.input.Input.load_data`. + dataset: Overriding methods can optionally include the dataset argument. Any attributes set on the dataset + (e.g. ``num_classes``) will also be set on the generated dataset. + + Returns: + The loaded sample as a mapping with string keys (e.g. "input", "target") that can be processed by the + :meth:`~flash.core.data.io.input_transform.InputTransform.pre_tensor_transform`. 
+ + Example:: + + # sample: {"input": "./cat/1.png", "target": 1} + # output: {"input": PIL.Image, "target": 1} + + output: Mapping[str, Any] = load_sample(sample) + + """ + return sample + + def to_datasets( + self, + train_data: Optional[DATA_TYPE] = None, + val_data: Optional[DATA_TYPE] = None, + test_data: Optional[DATA_TYPE] = None, + predict_data: Optional[DATA_TYPE] = None, + ) -> Tuple[Optional[BaseAutoDataset], ...]: + """Construct data sets (of type :class:`~flash.core.data.auto_dataset.BaseAutoDataset`) from this data + source by calling :meth:`~flash.core.data.io.input.Input.load_data` with each of the ``*_data`` arguments. + If an argument is given as ``None`` then no dataset will be created for that stage (``train``, ``val``, + ``test``, ``predict``). + + Args: + train_data: The input to :meth:`~flash.core.data.io.input.Input.load_data` to use to create the + train dataset. + val_data: The input to :meth:`~flash.core.data.io.input.Input.load_data` to use to create the + validation dataset. + test_data: The input to :meth:`~flash.core.data.io.input.Input.load_data` to use to create the + test dataset. + predict_data: The input to :meth:`~flash.core.data.io.input.Input.load_data` to use to create + the predict dataset. + + Returns: + A tuple of ``train_dataset``, ``val_dataset``, ``test_dataset``, ``predict_dataset``. If any ``*_data`` + argument is not passed to this method then the corresponding ``*_dataset`` will be ``None``. 
+ """ + train_dataset = self.generate_dataset(train_data, RunningStage.TRAINING) + val_dataset = self.generate_dataset(val_data, RunningStage.VALIDATING) + test_dataset = self.generate_dataset(test_data, RunningStage.TESTING) + predict_dataset = self.generate_dataset(predict_data, RunningStage.PREDICTING) + return train_dataset, val_dataset, test_dataset, predict_dataset + + def generate_dataset( + self, + data: Optional[DATA_TYPE], + running_stage: RunningStage, + ) -> Optional[Union[AutoDataset, IterableAutoDataset]]: + """Generate a single dataset with the given input to + :meth:`~flash.core.data.io.input.Input.load_data` for the given ``running_stage``. + + Args: + data: The input to :meth:`~flash.core.data.io.input.Input.load_data` to use to create the dataset. + running_stage: The running_stage for this dataset. + + Returns: + The constructed :class:`~flash.core.data.auto_dataset.BaseAutoDataset`. + """ + is_none = data is None + + if isinstance(data, Sequence): + is_none = data[0] is None + + if not is_none: + from flash.core.data.data_pipeline import DataPipeline + + mock_dataset = typing.cast(AutoDataset, MockDataset()) + with CurrentRunningStageFuncContext(running_stage, "load_data", self): + resolved_func_name = DataPipeline._resolve_function_hierarchy("load_data", self, running_stage, Input) + load_data: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr(self, resolved_func_name) + parameters = signature(load_data).parameters + if len(parameters) > 1 and "dataset" in parameters: # TODO: This was DATASET_KEY before + data = load_data(data, mock_dataset) + else: + data = load_data(data) + + if has_len(data): + dataset = AutoDataset(data, self, running_stage) + else: + dataset = IterableAutoDataset(data, self, running_stage) + dataset.__dict__.update(mock_dataset.metadata) + return dataset + + +SEQUENCE_DATA_TYPE = TypeVar("SEQUENCE_DATA_TYPE") + + +class DatasetInput(Input[Dataset]): + """The ``DatasetInput`` implements default behaviours for data 
sources which expect the input to + :meth:`~flash.core.data.io.input.Input.load_data` to be a :class:`torch.utils.data.dataset.Dataset` + + Args: + labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the + :class:`~flash.core.data.io.input.LabelsState`. + """ + + def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]: + if isinstance(sample, tuple) and len(sample) == 2: + return {DataKeys.INPUT: sample[0], DataKeys.TARGET: sample[1]} + return {DataKeys.INPUT: sample} + + +class SequenceInput( + Generic[SEQUENCE_DATA_TYPE], + Input[Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]]], +): + """The ``SequenceInput`` implements default behaviours for data sources which expect the input to + :meth:`~flash.core.data.io.input.Input.load_data` to be a sequence of tuples (``(input, target)`` + where target can be ``None``). + + Args: + labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the + :class:`~flash.core.data.io.input.LabelsState`. 
+ """ + + def __init__(self, labels: Optional[Sequence[str]] = None): + super().__init__() + + self.labels = labels + + if self.labels is not None: + self.set_state(LabelsState(self.labels)) + + def load_data( + self, + data: Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]], + dataset: Optional[Any] = None, + ) -> Sequence[Mapping[str, Any]]: + # TODO: Bring back the code to work out how many classes there are + inputs, targets = data + if targets is None: + return self.predict_load_data(data) + return [{DataKeys.INPUT: input, DataKeys.TARGET: target} for input, target in zip(inputs, targets)] + + @staticmethod + def predict_load_data(data: Sequence[SEQUENCE_DATA_TYPE]) -> Sequence[Mapping[str, Any]]: + return [{DataKeys.INPUT: input} for input in data] + + +class PathsInput(SequenceInput): + """The ``PathsInput`` implements default behaviours for data sources which expect the input to + :meth:`~flash.core.data.io.input.Input.load_data` to be either a directory with a subdirectory for + each class or a tuple containing list of files and corresponding list of targets. + + Args: + extensions: The file extensions supported by this data source (e.g. ``(".jpg", ".png")``). + labels: Optionally pass the labels as a mapping from class index to label string. These will then be set as the + :class:`~flash.core.data.io.input.LabelsState`. + """ + + def __init__( + self, + extensions: Optional[Tuple[str, ...]] = None, + loader: Optional[Callable[[str], Any]] = None, + labels: Optional[Sequence[str]] = None, + ): + super().__init__(labels=labels) + + self.extensions = extensions + self.loader = loader + + @staticmethod + def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: + """Finds the class folders in a dataset. Ensures that no class is a subdirectory of another. + + Args: + dir: Root directory path. + + Returns: + tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. 
+ """ + classes = [d.name for d in os.scandir(dir) if d.is_dir()] + classes.sort() + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + return classes, class_to_idx + + @staticmethod + def isdir(data: Union[str, Tuple[List[str], List[Any]]]) -> bool: + try: + return os.path.isdir(data) + except TypeError: + # data is not path-like (e.g. it may be a list of paths) + return False + + def load_data( + self, data: Union[str, Tuple[List[str], List[Any]]], dataset: Optional[Any] = None + ) -> Sequence[Mapping[str, Any]]: + if self.isdir(data): + classes, class_to_idx = self.find_classes(data) + if not classes: + return self.predict_load_data(data) + self.set_state(LabelsState(classes)) + + if dataset is not None: + dataset.num_classes = len(classes) + + data = make_dataset(data, class_to_idx, extensions=self.extensions) + return [{DataKeys.INPUT: input, DataKeys.TARGET: target} for input, target in data] + elif dataset is not None: + dataset.num_classes = len(np.unique(data[1])) + + return list( + filter( + lambda sample: has_file_allowed_extension(sample[DataKeys.INPUT], self.extensions), + super().load_data(data, dataset), + ) + ) + + def predict_load_data( + self, data: Union[str, List[str]], dataset: Optional[Any] = None + ) -> Sequence[Mapping[str, Any]]: + if self.isdir(data): + data = [os.path.join(data, file) for file in os.listdir(data)] + + if not isinstance(data, list): + data = [data] + + data = [{DataKeys.INPUT: input} for input in data] + + return list( + filter( + lambda sample: has_file_allowed_extension(sample[DataKeys.INPUT], self.extensions), + data, + ) + ) + + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: + path = sample[DataKeys.INPUT] + + if self.loader is not None: + sample[DataKeys.INPUT] = self.loader(path) + + sample[DataKeys.METADATA] = { + "filepath": path, + } + return sample + + +class LoaderDataFrameInput(Input[Tuple[pd.DataFrame, str, Union[str, List[str]], Optional[str], 
Optional[str]]]): + def __init__(self, loader: Callable[[str], Any]): + super().__init__() + + self.loader = loader + + @staticmethod + def _walk_files(root: str) -> Iterator[str]: + for root, _, files in os.walk(root): + for file in files: + yield os.path.join(root, file) + + @staticmethod + def _default_resolver(root: str, id: str): + if os.path.isabs(id): + return id + + pattern = f"*{id}*" + + try: + return str(next(Path(root).rglob(pattern))) + except StopIteration: + raise ValueError( + f"Found no matches for pattern: {pattern} in directory: {root}. File IDs should uniquely identify the " + "file to load." + ) + + @staticmethod + def _resolve_file(resolver: Callable[[str, str], str], root: str, input_key: str, row: pd.Series) -> pd.Series: + row[input_key] = resolver(root, row[input_key]) + return row + + @staticmethod + def _resolve_target(label_to_class: Dict[str, int], target_key: str, row: pd.Series) -> pd.Series: + row[target_key] = label_to_class[row[target_key]] + return row + + @staticmethod + def _resolve_multi_target(target_keys: List[str], row: pd.Series) -> pd.Series: + row[target_keys[0]] = [row[target_key] for target_key in target_keys] + return row + + def load_data( + self, + data: Tuple[pd.DataFrame, str, Union[str, List[str]], Optional[str], Optional[str]], + dataset: Optional[Any] = None, + ) -> Sequence[Mapping[str, Any]]: + data, input_key, target_keys, root, resolver = data + + if isinstance(data, (str, Path)): + data = str(data) + data_frame = pd.read_csv(data) + if root is None: + root = os.path.dirname(data) + else: + data_frame = data + + if root is None: + root = "" + + if resolver is None: + warnings.warn("Using default resolver, this may take a while.", UserWarning) + resolver = self._default_resolver + + tqdm.pandas(desc="Resolving files") + data_frame = data_frame.progress_apply(partial(self._resolve_file, resolver, root, input_key), axis=1) + + if not self.predicting: + if isinstance(target_keys, List): + dataset.multi_label = 
True + dataset.num_classes = len(target_keys) + self.set_state(LabelsState(target_keys)) + data_frame = data_frame.apply(partial(self._resolve_multi_target, target_keys), axis=1) + target_keys = target_keys[0] + else: + dataset.multi_label = False + if self.training: + labels = list(sorted(data_frame[target_keys].unique())) + dataset.num_classes = len(labels) + self.set_state(LabelsState(labels)) + + labels = self.get_state(LabelsState) + + if labels is not None: + labels = labels.labels + label_to_class = {v: k for k, v in enumerate(labels)} + data_frame = data_frame.apply(partial(self._resolve_target, label_to_class, target_keys), axis=1) + + return [ + { + DataKeys.INPUT: row[input_key], + DataKeys.TARGET: row[target_keys], + } + for _, row in data_frame.iterrows() + ] + return [ + { + DataKeys.INPUT: row[input_key], + } + for _, row in data_frame.iterrows() + ] + + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: + # TODO: simplify this duplicated code from PathsInput + path = sample[DataKeys.INPUT] + + if self.loader is not None: + sample[DataKeys.INPUT] = self.loader(path) + + sample[DataKeys.METADATA] = { + "filepath": path, + } + return sample + + +class TensorInput(SequenceInput[torch.Tensor]): + """The ``TensorInput`` is a ``SequenceInput`` which expects the input to + :meth:`~flash.core.data.io.input.Input.load_data` to be a sequence of ``torch.Tensor`` objects.""" + + def load_data( + self, + data: Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]], + dataset: Optional[Any] = None, + ) -> Sequence[Mapping[str, Any]]: + # TODO: Bring back the code to work out how many classes there are + if len(data) == 2: + dataset.num_classes = len(torch.unique(torch.tensor(data[1]))) + return super().load_data(data, dataset) + + +class NumpyInput(SequenceInput[np.ndarray]): + """The ``NumpyInput`` is a ``SequenceInput`` which expects the input to + :meth:`~flash.core.data.io.input.Input.load_data` to be a sequence of 
``np.ndarray`` objects.""" + + +class FiftyOneInput(Input[SampleCollection]): + """The ``FiftyOneInput`` expects the input to + :meth:`~flash.core.data.io.input.Input.load_data` to be a ``fiftyone.core.collections.SampleCollection``.""" + + def __init__(self, label_field: str = "ground_truth"): + super().__init__() + self.label_field = label_field + + @property + @requires("fiftyone") + def label_cls(self): + return fol.Label + + @requires("fiftyone") + def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: + self._validate(data) + + label_path = data._get_label_field_path(self.label_field, "label")[1] + + filepaths = data.values("filepath") + targets = data.values(label_path) + + classes = self._get_classes(data) + + if dataset is not None: + dataset.num_classes = len(classes) + + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + + if targets and isinstance(targets[0], list): + + def to_idx(t): + return [class_to_idx[x] for x in t] + + else: + + def to_idx(t): + return class_to_idx[t] + + return [ + { + DataKeys.INPUT: f, + DataKeys.TARGET: to_idx(t), + } + for f, t in zip(filepaths, targets) + ] + + @staticmethod + @requires("fiftyone") + def predict_load_data(data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: + return [{DataKeys.INPUT: f} for f in data.values("filepath")] + + def _validate(self, data): + label_type = data._get_label_field_type(self.label_field) + if not issubclass(label_type, self.label_cls): + raise ValueError(f"Expected field '{self.label_field}' to have type {self.label_cls}; found {label_type}") + + def _get_classes(self, data): + classes = data.classes.get(self.label_field, None) + + if not classes: + classes = data.default_classes + + if not classes: + label_path = data._get_label_field_path(self.label_field, "label")[1] + classes = data.distinct(label_path) + + return classes diff --git a/flash/core/data/io/input_transform.py 
b/flash/core/data/io/input_transform.py index ae759ed929..1fcfecd8fc 100644 --- a/flash/core/data/io/input_transform.py +++ b/flash/core/data/io/input_transform.py @@ -21,7 +21,7 @@ from torch.utils.data._utils.collate import default_collate from flash.core.data.callback import ControlFlow, FlashCallback -from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.io.input import DataKeys, DatasetInput, Input, InputFormat from flash.core.data.process import Deserializer from flash.core.data.properties import ProcessState, Properties from flash.core.data.states import ( @@ -196,9 +196,9 @@ def __init__( val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, predict_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - data_sources: Optional[Dict[str, "DataSource"]] = None, + inputs: Optional[Dict[str, "Input"]] = None, deserializer: Optional["Deserializer"] = None, - default_data_source: Optional[str] = None, + default_input: Optional[str] = None, ): super().__init__() @@ -225,12 +225,12 @@ def __init__( self._test_transform = convert_to_modules(self.test_transform) self._predict_transform = convert_to_modules(self.predict_transform) - if DefaultDataSources.DATASETS not in data_sources: - data_sources[DefaultDataSources.DATASETS] = DatasetDataSource() + if InputFormat.DATASETS not in inputs: + inputs[InputFormat.DATASETS] = DatasetInput() - self._data_sources = data_sources + self._inputs = inputs self._deserializer = deserializer - self._default_data_source = default_data_source + self._default_input = default_input self._callbacks: List[FlashCallback] = [] self._default_collate: Callable = default_collate @@ -268,9 +268,9 @@ def _check_transforms( return transform if isinstance(transform, list): - transform = {"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, 
torch.nn.Sequential(*transform))} + transform = {"pre_tensor_transform": ApplyToKeys(DataKeys.INPUT, torch.nn.Sequential(*transform))} elif callable(transform): - transform = {"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, transform)} + transform = {"pre_tensor_transform": ApplyToKeys(DataKeys.INPUT, transform)} if not isinstance(transform, Dict): raise MisconfigurationException( @@ -444,7 +444,7 @@ def collate(self, samples: Sequence, metadata=None) -> Any: # return collate_fn.collate_fn(samples) parameters = inspect.signature(collate_fn).parameters - if len(parameters) > 1 and DefaultDataKeys.METADATA in parameters: + if len(parameters) > 1 and DataKeys.METADATA in parameters: return collate_fn(samples, metadata) return collate_fn(samples) @@ -473,37 +473,37 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: """ return self._apply_process_state_transform(PerBatchTransformOnDevice, batch=batch) - def available_data_sources(self) -> Sequence[str]: + def available_inputs(self) -> Sequence[str]: """Get the list of available data source names for use with this :class:`~flash.core.data.io.input_transform.InputTransform`. Returns: The list of data source names. """ - return list(self._data_sources.keys()) + return list(self._inputs.keys()) - def data_source_of_name(self, data_source_name: str) -> DataSource: - """Get the :class:`~flash.core.data.data_source.DataSource` of the given name from the + def input_of_name(self, input_name: str) -> Input: + """Get the :class:`~flash.core.data.io.input.Input` of the given name from the :class:`~flash.core.data.io.input_transform.InputTransform`. Args: - data_source_name: The name of the data source to look up. + input_name: The name of the data source to look up. Returns: - The :class:`~flash.core.data.data_source.DataSource` of the given name. + The :class:`~flash.core.data.io.input.Input` of the given name. 
Raises: MisconfigurationException: If the requested data source is not configured by this :class:`~flash.core.data.io.input_transform.InputTransform`. """ - if data_source_name == "default": - data_source_name = self._default_data_source - data_sources = self._data_sources - if data_source_name in data_sources: - return data_sources[data_source_name] + if input_name == "default": + input_name = self._default_input + inputs = self._inputs + if input_name in inputs: + return inputs[input_name] raise MisconfigurationException( - f"No '{data_source_name}' data source is available for use with the {type(self)}. The available data " - f"sources are: {', '.join(self.available_data_sources())}." + f"No '{input_name}' data source is available for use with the {type(self)}. The available data " + f"sources are: {', '.join(self.available_inputs())}." ) @@ -514,16 +514,16 @@ def __init__( val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, predict_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None, - data_sources: Optional[Dict[str, "DataSource"]] = None, - default_data_source: Optional[str] = None, + inputs: Optional[Dict[str, "Input"]] = None, + default_input: Optional[str] = None, ): super().__init__( train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources=data_sources or {"default": DataSource()}, - default_data_source=default_data_source or "default", + inputs=inputs or {"default": Input()}, + default_input=default_input or "default", ) def get_state_dict(self) -> Dict[str, Any]: @@ -655,7 +655,7 @@ def __init__( def _extract_metadata( samples: List[Dict[str, Any]], ) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]: - metadata = [s.pop(DefaultDataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples] + metadata = 
[s.pop(DataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples] return samples, metadata if any(m is not None for m in metadata) else None def forward(self, samples: Sequence[Any]) -> Any: @@ -689,7 +689,7 @@ def forward(self, samples: Sequence[Any]) -> Any: except TypeError: samples = self.collate_fn(samples) if metadata and isinstance(samples, dict): - samples[DefaultDataKeys.METADATA] = metadata + samples[DataKeys.METADATA] = metadata self.callback.on_collate(samples, self.stage) with self._per_batch_transform_context: diff --git a/flash/core/data/io/output_transform.py b/flash/core/data/io/output_transform.py index f9764fc3aa..af1f728a9f 100644 --- a/flash/core/data/io/output_transform.py +++ b/flash/core/data/io/output_transform.py @@ -17,7 +17,7 @@ from torch import Tensor from flash.core.data.batch import default_uncollate -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.properties import Properties from flash.core.data.utils import convert_to_modules @@ -80,8 +80,8 @@ def __init__( @staticmethod def _extract_metadata(batch: Any) -> Tuple[Any, Optional[Any]]: metadata = None - if isinstance(batch, Mapping) and DefaultDataKeys.METADATA in batch: - metadata = batch.pop(DefaultDataKeys.METADATA, None) + if isinstance(batch, Mapping) and DataKeys.METADATA in batch: + metadata = batch.pop(DataKeys.METADATA, None) return batch, metadata def forward(self, batch: Sequence[Any]): @@ -89,7 +89,7 @@ def forward(self, batch: Sequence[Any]): uncollated = self.uncollate_fn(self.per_batch_transform(batch)) if metadata: for sample, sample_metadata in zip(uncollated, metadata): - sample[DefaultDataKeys.METADATA] = sample_metadata + sample[DataKeys.METADATA] = sample_metadata final_preds = [self.per_sample_transform(sample) for sample in uncollated] diff --git a/flash/core/data/new_data_module.py b/flash/core/data/new_data_module.py index 69b370c278..5ace9ba78f 100644 --- 
a/flash/core/data/new_data_module.py +++ b/flash/core/data/new_data_module.py @@ -319,7 +319,7 @@ def _verify_flash_dataset_enum(cls, enum: LightningEnum) -> None: if not cls.flash_datasets_registry or not isinstance(cls.flash_datasets_registry, FlashRegistry): raise MisconfigurationException( "The ``AutoContainer`` should have ``flash_datasets_registry`` (FlashRegistry) populated " - "with datasource class and ``default_flash_dataset_enum`` (LightningEnum) class attributes. " + "with Input class and ``default_flash_dataset_enum`` (LightningEnum) class attributes. " ) if enum not in cls.flash_datasets_registry.available_keys(): diff --git a/flash/core/data/output.py b/flash/core/data/output.py index ad91f4494e..644c778f8c 100644 --- a/flash/core/data/output.py +++ b/flash/core/data/output.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, List, Union -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.io.output import Output @@ -21,4 +21,4 @@ class Preds(Output): """A :class:`~flash.core.data.io.output.Output` which returns the "preds" from the model outputs.""" def transform(self, sample: Any) -> Union[int, List[int]]: - return sample.get(DefaultDataKeys.PREDS, sample) if isinstance(sample, dict) else sample + return sample.get(DataKeys.PREDS, sample) if isinstance(sample, dict) else sample diff --git a/flash/core/data/utils.py b/flash/core/data/utils.py index 37244f5b1b..4810c9fca7 100644 --- a/flash/core/data/utils.py +++ b/flash/core/data/utils.py @@ -42,7 +42,7 @@ } _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"} -_DATASOURCE_FUNCS: Set[str] = { +_INPUT_FUNCS: Set[str] = { "load_data", "load_sample", } diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index 1e6c7d48a9..e723bc2cd5 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -19,7 
+19,7 @@ import flash from flash.core.adapter import Adapter from flash.core.data.auto_dataset import BaseAutoDataset -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.integrations.icevision.transforms import from_icevision_predictions, to_icevision_record from flash.core.model import Task from flash.core.utilities.imports import _ICEVISION_AVAILABLE @@ -83,10 +83,10 @@ def from_task( def _collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] = None): metadata = metadata or [None] * len(samples) return { - DefaultDataKeys.INPUT: collate_fn( - [to_icevision_record({**sample, DefaultDataKeys.METADATA: m}) for sample, m in zip(samples, metadata)] + DataKeys.INPUT: collate_fn( + [to_icevision_record({**sample, DataKeys.METADATA: m}) for sample, m in zip(samples, metadata)] ), - DefaultDataKeys.METADATA: metadata, + DataKeys.METADATA: metadata, } def process_train_dataset( @@ -185,16 +185,16 @@ def process_predict_dataset( return data_loader def training_step(self, batch, batch_idx) -> Any: - return self.icevision_adapter.training_step(batch[DefaultDataKeys.INPUT], batch_idx) + return self.icevision_adapter.training_step(batch[DataKeys.INPUT], batch_idx) def validation_step(self, batch, batch_idx): - return self.icevision_adapter.validation_step(batch[DefaultDataKeys.INPUT], batch_idx) + return self.icevision_adapter.validation_step(batch[DataKeys.INPUT], batch_idx) def test_step(self, batch, batch_idx): - return self.icevision_adapter.validation_step(batch[DefaultDataKeys.INPUT], batch_idx) + return self.icevision_adapter.validation_step(batch[DataKeys.INPUT], batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch[DefaultDataKeys.PREDS] = self(batch[DefaultDataKeys.INPUT]) + batch[DataKeys.PREDS] = self(batch[DataKeys.INPUT]) return batch def forward(self, batch: Any) -> Any: diff --git a/flash/core/integrations/icevision/data.py 
b/flash/core/integrations/icevision/data.py index 246ace7e13..27446183c9 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -16,10 +16,10 @@ import numpy as np -from flash.core.data.data_source import DefaultDataKeys, LabelsState +from flash.core.data.io.input import DataKeys, LabelsState from flash.core.integrations.icevision.transforms import from_icevision_record from flash.core.utilities.imports import _ICEVISION_AVAILABLE -from flash.image.data import ImagePathsDataSource +from flash.image.data import ImagePathsInput if _ICEVISION_AVAILABLE: from icevision.core.record import BaseRecord @@ -28,22 +28,22 @@ from icevision.parsers.parser import Parser -class IceVisionPathsDataSource(ImagePathsDataSource): +class IceVisionPathsInput(ImagePathsInput): def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: return super().predict_load_data(data, dataset) def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - record = sample[DefaultDataKeys.INPUT].load() + record = sample[DataKeys.INPUT].load() return from_icevision_record(record) def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - if isinstance(sample[DefaultDataKeys.INPUT], BaseRecord): + if isinstance(sample[DataKeys.INPUT], BaseRecord): # load the data via IceVision Base Record return self.load_sample(sample) # load the data using numpy - filepath = sample[DefaultDataKeys.INPUT] + filepath = sample[DataKeys.INPUT] sample = super().load_sample(sample) - image = np.array(sample[DefaultDataKeys.INPUT]) + image = np.array(sample[DataKeys.INPUT]) record = BaseRecord([FilepathRecordComponent()]) record.filepath = filepath @@ -52,7 +52,7 @@ def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: return from_icevision_record(record) -class IceVisionParserDataSource(IceVisionPathsDataSource): +class IceVisionParserInput(IceVisionPathsInput): def __init__(self, 
parser: Optional[Type["Parser"]] = None): super().__init__() self.parser = parser @@ -69,7 +69,7 @@ def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Seq dataset.num_classes = parser.class_map.num_classes self.set_state(LabelsState([parser.class_map.get_by_id(i) for i in range(dataset.num_classes)])) records = parser.parse(data_splitter=SingleSplitSplitter()) - return [{DefaultDataKeys.INPUT: record} for record in records[0]] + return [{DataKeys.INPUT: record} for record in records[0]] raise ValueError("The parser argument must be provided.") def predict_load_data(self, data: Any, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py index 5619dfd5af..e254ea4e80 100644 --- a/flash/core/integrations/icevision/transforms.py +++ b/flash/core/integrations/icevision/transforms.py @@ -15,7 +15,7 @@ from torch import nn -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires if _ICEVISION_AVAILABLE: @@ -41,7 +41,7 @@ def to_icevision_record(sample: Dict[str, Any]): record = BaseRecord([]) - metadata = sample.get(DefaultDataKeys.METADATA, None) or {} + metadata = sample.get(DataKeys.METADATA, None) or {} if "image_id" in metadata: record_id_component = RecordIDRecordComponent() @@ -51,31 +51,31 @@ def to_icevision_record(sample: Dict[str, Any]): component.set_class_map(metadata.get("class_map", None)) record.add_component(component) - if "labels" in sample[DefaultDataKeys.TARGET]: + if "labels" in sample[DataKeys.TARGET]: labels_component = InstancesLabelsRecordComponent() - labels_component.add_labels_by_id(sample[DefaultDataKeys.TARGET]["labels"]) + labels_component.add_labels_by_id(sample[DataKeys.TARGET]["labels"]) record.add_component(labels_component) - if "bboxes" in sample[DefaultDataKeys.TARGET]: + if 
"bboxes" in sample[DataKeys.TARGET]: bboxes = [ BBox.from_xywh(bbox["xmin"], bbox["ymin"], bbox["width"], bbox["height"]) - for bbox in sample[DefaultDataKeys.TARGET]["bboxes"] + for bbox in sample[DataKeys.TARGET]["bboxes"] ] component = BBoxesRecordComponent() component.set_bboxes(bboxes) record.add_component(component) - if "masks" in sample[DefaultDataKeys.TARGET]: - mask_array = MaskArray(sample[DefaultDataKeys.TARGET]["masks"]) + if "masks" in sample[DataKeys.TARGET]: + mask_array = MaskArray(sample[DataKeys.TARGET]["masks"]) component = MasksRecordComponent() component.set_masks(mask_array) record.add_component(component) - if "keypoints" in sample[DefaultDataKeys.TARGET]: + if "keypoints" in sample[DataKeys.TARGET]: keypoints = [] for keypoints_list, keypoints_metadata in zip( - sample[DefaultDataKeys.TARGET]["keypoints"], sample[DefaultDataKeys.TARGET]["keypoints_metadata"] + sample[DataKeys.TARGET]["keypoints"], sample[DataKeys.TARGET]["keypoints_metadata"] ): xyv = [] for keypoint in keypoints_list: @@ -86,9 +86,9 @@ def to_icevision_record(sample: Dict[str, Any]): component.set_keypoints(keypoints) record.add_component(component) - if isinstance(sample[DefaultDataKeys.INPUT], str): + if isinstance(sample[DataKeys.INPUT], str): input_component = FilepathRecordComponent() - input_component.set_filepath(sample[DefaultDataKeys.INPUT]) + input_component.set_filepath(sample[DataKeys.INPUT]) else: if "filepath" in metadata: input_component = FilepathRecordComponent() @@ -96,7 +96,7 @@ def to_icevision_record(sample: Dict[str, Any]): else: input_component = ImageRecordComponent() input_component.composite = record - input_component.set_img(sample[DefaultDataKeys.INPUT]) + input_component.set_img(sample[DataKeys.INPUT]) record.add_component(input_component) return record @@ -161,29 +161,29 @@ def from_icevision_detection(record: "BaseRecord"): def from_icevision_record(record: "BaseRecord"): sample = { - DefaultDataKeys.METADATA: { + DataKeys.METADATA: { "size": 
(record.height, record.width), } } if getattr(record, "record_id", None) is not None: - sample[DefaultDataKeys.METADATA]["image_id"] = record.record_id + sample[DataKeys.METADATA]["image_id"] = record.record_id if getattr(record, "filepath", None) is not None: - sample[DefaultDataKeys.METADATA]["filepath"] = record.filepath + sample[DataKeys.METADATA]["filepath"] = record.filepath if record.img is not None: - sample[DefaultDataKeys.INPUT] = record.img + sample[DataKeys.INPUT] = record.img filepath = getattr(record, "filepath", None) if filepath is not None: - sample[DefaultDataKeys.METADATA]["filepath"] = filepath + sample[DataKeys.METADATA]["filepath"] = filepath elif record.filepath is not None: - sample[DefaultDataKeys.INPUT] = record.filepath + sample[DataKeys.INPUT] = record.filepath - sample[DefaultDataKeys.TARGET] = from_icevision_detection(record) + sample[DataKeys.TARGET] = from_icevision_detection(record) if getattr(record.detection, "class_map", None) is not None: - sample[DefaultDataKeys.METADATA]["class_map"] = record.detection.class_map + sample[DataKeys.METADATA]["class_map"] = record.detection.class_map return sample diff --git a/flash/core/integrations/labelstudio/data_source.py b/flash/core/integrations/labelstudio/input.py similarity index 89% rename from flash/core/integrations/labelstudio/data_source.py rename to flash/core/integrations/labelstudio/input.py index 1d6b69c8cb..737be6caee 100644 --- a/flash/core/integrations/labelstudio/data_source.py +++ b/flash/core/integrations/labelstudio/input.py @@ -7,7 +7,7 @@ from pytorch_lightning.utilities.cloud_io import get_filesystem from flash.core.data.auto_dataset import AutoDataset, IterableAutoDataset -from flash.core.data.data_source import DataSource, DefaultDataKeys, has_len +from flash.core.data.io.input import DataKeys, has_len, Input from flash.core.data.utils import image_default_loader from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE, _TEXT_AVAILABLE from 
flash.core.utilities.stages import RunningStage @@ -16,9 +16,9 @@ from transformers import AutoTokenizer -class LabelStudioDataSource(DataSource): - """The ``LabelStudioDatasource`` expects the input to - :meth:`~flash.core.data.data_source.DataSource.load_data` to be a json export from label studio.""" +class LabelStudioInput(Input): + """The ``LabelStudioInput`` expects the input to + :meth:`~flash.core.data.io.input.Input.load_data` to be a json export from label studio.""" def __init__(self): super().__init__() @@ -44,7 +44,7 @@ def load_data(self, data: Optional[Any] = None, dataset: Optional[Any] = None) - _raw_data = json.load(f) self.multi_label = data.get("multi_label", False) self.split = data.get("split") - results, test_results, classes, data_types, tag_types = LabelStudioDataSource._load_json_data( + results, test_results, classes, data_types, tag_types = LabelStudioInput._load_json_data( _raw_data, data_folder=data_folder, multi_label=self.multi_label ) self.classes = self.classes | classes @@ -72,8 +72,8 @@ def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] = # delete label from input data del sample["label"] result = { - DefaultDataKeys.INPUT: sample, - DefaultDataKeys.TARGET: label, + DataKeys.INPUT: sample, + DataKeys.TARGET: label, } return result @@ -177,9 +177,9 @@ def _load_json_data(data, data_folder, multi_label=False): return results, test_results, classes, data_types, tag_types -class LabelStudioImageClassificationDataSource(LabelStudioDataSource): - """The ``LabelStudioImageDataSource`` expects the input to - :meth:`~flash.core.data.data_source.DataSource.load_data` to be a json export from label studio. +class LabelStudioImageClassificationInput(LabelStudioInput): + """The ``LabelStudioImageInput`` expects the input to + :meth:`~flash.core.data.io.input.Input.load_data` to be a json export from label studio. 
Export data should point to image files""" def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] = None) -> Any: @@ -187,13 +187,13 @@ def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] = p = sample["file_upload"] # loading image image = image_default_loader(p) - result = {DefaultDataKeys.INPUT: image, DefaultDataKeys.TARGET: self._get_labels_from_sample(sample["label"])} + result = {DataKeys.INPUT: image, DataKeys.TARGET: self._get_labels_from_sample(sample["label"])} return result -class LabelStudioTextClassificationDataSource(LabelStudioDataSource): - """The ``LabelStudioTextDataSource`` expects the input to - :meth:`~flash.core.data.data_source.DataSource.load_data` to be a json export from label studio. +class LabelStudioTextClassificationInput(LabelStudioInput): + """The ``LabelStudioTextInput`` expects the input to + :meth:`~flash.core.data.io.input.Input.load_data` to be a json export from label studio. Export data should point to text data """ @@ -219,9 +219,9 @@ def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] = return result -class LabelStudioVideoClassificationDataSource(LabelStudioDataSource): - """The ``LabelStudioVideoDataSource`` expects the input to - :meth:`~flash.core.data.data_source.DataSource.load_data` to be a json export from label studio. +class LabelStudioVideoClassificationInput(LabelStudioInput): + """The ``LabelStudioVideoInput`` expects the input to + :meth:`~flash.core.data.io.input.Input.load_data` to be a json export from label studio. 
Export data should point to video files""" def __init__(self, video_sampler=None, clip_sampler=None, decode_audio=False, decoder: str = "pyav"): diff --git a/flash/core/integrations/labelstudio/visualizer.py b/flash/core/integrations/labelstudio/visualizer.py index a284eee10b..d8fb186362 100644 --- a/flash/core/integrations/labelstudio/visualizer.py +++ b/flash/core/integrations/labelstudio/visualizer.py @@ -23,7 +23,7 @@ def show_predictions(self, predictions): def show_tasks(self, predictions, export_json=None): """Converts predictions to tasks format.""" results = self.show_predictions(predictions) - ds = self.datamodule.data_source + ds = self.datamodule.input data_type = list(ds.data_types)[0] meta = {"ids": [], "data": [], "meta": [], "max_predictions_id": 0, "project": None} if export_json: @@ -69,7 +69,7 @@ def show_tasks(self, predictions, export_json=None): def construct_result(self, pred): """Construction Label Studio result from data source and prediction values.""" - ds = self.datamodule.data_source + ds = self.datamodule.input # get label if isinstance(pred, list): label = [list(ds.classes)[p] for p in pred] diff --git a/flash/core/integrations/pytorch_forecasting/adapter.py b/flash/core/integrations/pytorch_forecasting/adapter.py index 63f161cf2d..69da066074 100644 --- a/flash/core/integrations/pytorch_forecasting/adapter.py +++ b/flash/core/integrations/pytorch_forecasting/adapter.py @@ -20,7 +20,7 @@ from flash.core.adapter import Adapter from flash.core.data.batch import default_uncollate -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.states import CollateFn from flash.core.model import Task from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE @@ -60,9 +60,9 @@ def __init__(self, backbone): @staticmethod def _collate_fn(collate_fn, samples): - samples = [(sample[DefaultDataKeys.INPUT], sample[DefaultDataKeys.TARGET]) for sample in samples] + 
samples = [(sample[DataKeys.INPUT], sample[DataKeys.TARGET]) for sample in samples] batch = collate_fn(samples) - return {DefaultDataKeys.INPUT: batch[0], DefaultDataKeys.TARGET: batch[1]} + return {DataKeys.INPUT: batch[0], DataKeys.TARGET: batch[1]} @classmethod def from_task( @@ -95,11 +95,11 @@ def from_task( return adapter def training_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return self.backbone.training_step(batch, batch_idx) def validation_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return self.backbone.validation_step(batch, batch_idx) def test_step(self, batch: Any, batch_idx: int) -> None: @@ -108,8 +108,8 @@ def test_step(self, batch: Any, batch_idx: int) -> None: ) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - result = dict(self.backbone(batch[DefaultDataKeys.INPUT])) - result[DefaultDataKeys.INPUT] = default_uncollate(batch[DefaultDataKeys.INPUT]) + result = dict(self.backbone(batch[DataKeys.INPUT])) + result[DataKeys.INPUT] = default_uncollate(batch[DataKeys.INPUT]) return default_uncollate(result) def training_epoch_end(self, outputs) -> None: diff --git a/flash/core/integrations/pytorch_forecasting/transforms.py b/flash/core/integrations/pytorch_forecasting/transforms.py index ce193d0bcc..25db7b378d 100644 --- a/flash/core/integrations/pytorch_forecasting/transforms.py +++ b/flash/core/integrations/pytorch_forecasting/transforms.py @@ -15,7 +15,7 @@ from torch.utils.data._utils.collate import default_collate -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys def convert_predictions(predictions: List[Dict[str, Any]]) -> Tuple[Dict[str, Any], List]: @@ -26,5 +26,5 @@ def 
convert_predictions(predictions: List[Dict[str, Any]]) -> Tuple[Dict[str, An unrolled_predictions.extend(prediction_batch) predictions = unrolled_predictions result = default_collate(predictions) - inputs = result.pop(DefaultDataKeys.INPUT) + inputs = result.pop(DataKeys.INPUT) return result, inputs diff --git a/flash/core/model.py b/flash/core/model.py index 6bf1ad900c..bead06191e 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -14,6 +14,7 @@ import functools import inspect import pickle +import warnings from abc import ABCMeta from copy import deepcopy from importlib import import_module @@ -38,7 +39,7 @@ import flash from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_pipeline import DataPipeline, DataPipelineState -from flash.core.data.data_source import DataSource +from flash.core.data.io.input import Input from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output import Output from flash.core.data.io.output_transform import OutputTransform @@ -471,6 +472,7 @@ def predict( self, x: Any, data_source: Optional[str] = None, + input: Optional[str] = None, deserializer: Optional[Deserializer] = None, data_pipeline: Optional[DataPipeline] = None, ) -> Any: @@ -478,7 +480,7 @@ def predict( Args: x: Input to predict. Can be raw data or processed data. If str, assumed to be a folder of data. - data_source: A string that indicates the format of the data source to use which will override + input: A string that indicates the format of the data source to use which will override the current data source format used deserializer: A single :class:`~flash.core.data.process.Deserializer` to deserialize the input data_pipeline: Use this to override the current data pipeline @@ -486,10 +488,17 @@ def predict( Returns: The post-processed model predictions """ + if data_source is not None: + warnings.warn( + "The `data_source` argument has been deprecated since 0.6.0 and will be removed in 0.7.0. 
Use `input` " + "instead.", + FutureWarning, + ) + input = data_source running_stage = RunningStage.PREDICTING - data_pipeline = self.build_data_pipeline(data_source or "default", deserializer, data_pipeline) - dataset = data_pipeline.data_source.generate_dataset(x, running_stage) + data_pipeline = self.build_data_pipeline(input or "default", deserializer, data_pipeline) + dataset = data_pipeline.input.generate_dataset(x, running_stage) dataloader = self.process_predict_dataset(dataset) x = list(dataloader.dataset) x = data_pipeline.worker_input_transform_processor(running_stage, collate_fn=dataloader.collate_fn)(x) @@ -671,7 +680,7 @@ def serializer(self, serializer: Output): def build_data_pipeline( self, - data_source: Optional[str] = None, + input: Optional[str] = None, deserializer: Optional[Deserializer] = None, data_pipeline: Optional[DataPipeline] = None, ) -> Optional[DataPipeline]: @@ -686,7 +695,7 @@ def build_data_pipeline( - :class:`.DataPipeline` passed to this method. Args: - data_source: A string that indicates the format of the data source to use which will override + input: A string that indicates the format of the data source to use which will override the current data source format used. deserializer: deserializer to use data_pipeline: Optional highest priority source of @@ -696,7 +705,7 @@ def build_data_pipeline( Returns: The fully resolved :class:`.DataPipeline`. 
""" - deserializer, old_data_source, input_transform, output_transform, output = None, None, None, None, None + deserializer, old_input, input_transform, output_transform, output = None, None, None, None, None # Datamodule datamodule = None @@ -706,7 +715,7 @@ def build_data_pipeline( datamodule = self.datamodule if getattr(datamodule, "data_pipeline", None) is not None: - old_data_source = getattr(datamodule.data_pipeline, "data_source", None) + old_input = getattr(datamodule.data_pipeline, "input", None) input_transform = getattr(datamodule.data_pipeline, "_input_transform_pipeline", None) output_transform = getattr(datamodule.data_pipeline, "_output_transform", None) output = getattr(datamodule.data_pipeline, "_output", None) @@ -737,18 +746,18 @@ def build_data_pipeline( getattr(data_pipeline, "_output", None), ) - data_source = data_source or old_data_source + input = input or old_input - if isinstance(data_source, str): + if isinstance(input, str): if input_transform is None: - data_source = DataSource() # TODO: warn the user that we are not using the specified data source + input = Input() # TODO: warn the user that we are not using the specified data source else: - data_source = input_transform.data_source_of_name(data_source) + input = input_transform.input_of_name(input) if deserializer is None or type(deserializer) is Deserializer: deserializer = getattr(input_transform, "deserializer", deserializer) - data_pipeline = DataPipeline(data_source, input_transform, output_transform, deserializer, output) + data_pipeline = DataPipeline(input, input_transform, output_transform, deserializer, output) self._data_pipeline_state = self._data_pipeline_state or DataPipelineState() self.attach_data_pipeline_state(self._data_pipeline_state) self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) diff --git a/flash/core/serve/flash_components.py b/flash/core/serve/flash_components.py index 1ff8574c9b..0149025979 100644 --- 
a/flash/core/serve/flash_components.py +++ b/flash/core/serve/flash_components.py @@ -3,7 +3,7 @@ import torch -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.serve import expose, ModelComponent from flash.core.serve.types.base import BaseType from flash.core.utilities.stages import RunningStage @@ -36,7 +36,7 @@ def serialize(self, outputs) -> Any: # pragma: no cover for output in outputs: result = self._output(output) if isinstance(result, Mapping): - result = result[DefaultDataKeys.PREDS] + result = result[DataKeys.PREDS] results.append(result) if len(results) == 1: return results[0] diff --git a/flash/core/utilities/flash_cli.py b/flash/core/utilities/flash_cli.py index f015da385f..6955972e1c 100644 --- a/flash/core/utilities/flash_cli.py +++ b/flash/core/utilities/flash_cli.py @@ -25,7 +25,7 @@ from pytorch_lightning import LightningModule, Trainer import flash -from flash.core.data.data_source import DefaultDataSources +from flash.core.data.io.input import InputFormat from flash.core.utilities.lightning_cli import ( class_from_function, LightningArgumentParser, @@ -126,9 +126,9 @@ def __init__( datamodule_class: The :class:`~flash.data.data_module.DataModule` class. trainer_class: An optional extension of the :class:`pytorch_lightning.Trainer` class. trainer_fn: The trainer function to run. - datasource: Use this if your ``DataModule`` is created using a classmethod. Any of: + input: Use this if your ``DataModule`` is created using a classmethod. Any of: - ``None``. The ``datamodule_class.__init__`` signature will be used. - - ``str``. One of :class:`~flash.data.data_source.DefaultDataSources`. This will use the signature of + - ``str``. One of :class:`~flash.data.io.input.InputFormat`. This will use the signature of the corresponding ``DataModule.from_*`` method. - ``Callable``. A custom method. 
kwargs: See the parent arguments @@ -180,15 +180,13 @@ def parse_arguments(self) -> None: def add_arguments_to_parser(self, parser) -> None: subcommands = parser.add_subcommands() - data_sources = self.local_datamodule_class.input_transform_cls().available_data_sources() + inputs = self.local_datamodule_class.input_transform_cls().available_inputs() - for data_source in data_sources: - if isinstance(data_source, DefaultDataSources): - data_source = data_source.value - if hasattr(self.local_datamodule_class, f"from_{data_source}"): - self.add_subcommand_from_function( - subcommands, getattr(self.local_datamodule_class, f"from_{data_source}") - ) + for input in inputs: + if isinstance(input, InputFormat): + input = input.value + if hasattr(self.local_datamodule_class, f"from_{input}"): + self.add_subcommand_from_function(subcommands, getattr(self.local_datamodule_class, f"from_{input}")) for datamodule_builder in self.additional_datamodule_builders: self.add_subcommand_from_function(subcommands, datamodule_builder) diff --git a/flash/core/utilities/on_access_enum_meta.py b/flash/core/utilities/on_access_enum_meta.py new file mode 100644 index 0000000000..9f8f63f2ae --- /dev/null +++ b/flash/core/utilities/on_access_enum_meta.py @@ -0,0 +1,27 @@ +from enum import Enum, EnumMeta + + +class OnAccessEnumMeta(EnumMeta): + """Enum with a hook to run a function whenever a member is accessed. 
+ + Adapted from: + https://www.buzzphp.com/posts/how-do-i-detect-and-invoke-a-function-when-a-python-enum-member-is-accessed + """ + + def __getattribute__(cls, name): + obj = super().__getattribute__(name) + if isinstance(obj, Enum) and obj._on_access: + obj._on_access() + return obj + + def __getitem__(cls, name): + member = super().__getitem__(name) + if member._on_access: + member._on_access() + return member + + def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, start=1): + obj = super().__call__(value, names, module=module, qualname=qualname, type=type, start=start) + if isinstance(obj, Enum) and obj._on_access: + obj._on_access() + return obj diff --git a/flash/graph/classification/data.py b/flash/graph/classification/data.py index e997a3279e..aaff7ed998 100644 --- a/flash/graph/classification/data.py +++ b/flash/graph/classification/data.py @@ -14,10 +14,10 @@ from typing import Any, Callable, Dict, Optional from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DefaultDataSources +from flash.core.data.io.input import InputFormat from flash.core.data.io.input_transform import InputTransform from flash.core.utilities.imports import _GRAPH_AVAILABLE -from flash.graph.data import GraphDatasetDataSource +from flash.graph.data import GraphDatasetInput if _GRAPH_AVAILABLE: from torch_geometric.data.batch import Batch @@ -37,10 +37,10 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.DATASETS: GraphDatasetDataSource(), + inputs={ + InputFormat.DATASETS: GraphDatasetInput(), }, - default_data_source=DefaultDataSources.DATASETS, + default_input=InputFormat.DATASETS, ) def get_state_dict(self) -> Dict[str, Any]: diff --git a/flash/graph/data.py b/flash/graph/data.py index b372b901f9..3924508aa0 100644 --- a/flash/graph/data.py +++ b/flash/graph/data.py @@ -15,7 +15,7 @@ from torch.utils.data import 
Dataset -from flash.core.data.data_source import DatasetDataSource +from flash.core.data.io.input import DatasetInput from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires if _GRAPH_AVAILABLE: @@ -23,7 +23,7 @@ from torch_geometric.data import Dataset as TorchGeometricDataset -class GraphDatasetDataSource(DatasetDataSource): +class GraphDatasetInput(DatasetInput): @requires("graph") def load_data(self, data: Dataset, dataset: Any = None) -> Dataset: data = super().load_data(data, dataset=dataset) diff --git a/flash/image/classification/adapters.py b/flash/image/classification/adapters.py index ec141ecb54..9a5faeb092 100644 --- a/flash/image/classification/adapters.py +++ b/flash/image/classification/adapters.py @@ -29,7 +29,7 @@ import flash from flash.core.adapter import Adapter, AdapterTask from flash.core.data.auto_dataset import BaseAutoDataset -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.compatibility import accelerator_connector @@ -53,7 +53,7 @@ class Learn2LearnRemapLabels: class RemapLabels(Learn2LearnRemapLabels): def remap(self, data, mapping): # remap needs to be adapted to Flash API. 
- data[DefaultDataKeys.TARGET] = mapping(data[DefaultDataKeys.TARGET]) + data[DataKeys.TARGET] = mapping(data[DataKeys.TARGET]) return data @@ -208,7 +208,7 @@ def _default_transform(self, dataset, ways: int, shots: int, queries) -> List[Ca def _labels_to_indices(data): out = defaultdict(list) for idx, sample in enumerate(data): - label = sample[DefaultDataKeys.TARGET] + label = sample[DataKeys.TARGET] if torch.is_tensor(label): label = label.item() out[label].append(idx) @@ -312,25 +312,25 @@ def from_task( return cls(task, backbone, head, algorithm, **kwargs) def training_step(self, batch, batch_idx) -> Any: - input = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + input = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return self.model.training_step(input, batch_idx) def validation_step(self, batch, batch_idx): # Should be True only for trainer.validate if self.trainer.state.fn == TrainerFn.VALIDATING: self._algorithm_has_validated = True - input = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + input = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return self.model.validation_step(input, batch_idx) def validation_epoch_end(self, outpus: Any): self.model.validation_epoch_end(outpus) def test_step(self, batch, batch_idx): - input = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + input = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return self.model.test_step(input, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - return self.model.predict_step(batch[DefaultDataKeys.INPUT], batch_idx, dataloader_idx=dataloader_idx) + return self.model.predict_step(batch[DataKeys.INPUT], batch_idx, dataloader_idx=dataloader_idx) def _sanetize_batch_size(self, batch_size: int) -> int: if batch_size != 1: @@ -506,20 +506,20 @@ def from_task( return cls(task, backbone, head) def training_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], 
batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return Task.training_step(self._task.task, batch, batch_idx) def validation_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return Task.validation_step(self._task.task, batch, batch_idx) def test_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return Task.test_step(self._task.task, batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch[DefaultDataKeys.PREDS] = Task.predict_step( - self._task.task, (batch[DefaultDataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx + batch[DataKeys.PREDS] = Task.predict_step( + self._task.task, (batch[DataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx ) return batch diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index 45fd5a1648..35f0ddb759 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -21,20 +21,20 @@ from flash.core.data.base_viz import BaseVisualization # for viz from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, LoaderDataFrameDataSource +from flash.core.data.io.input import DataKeys, InputFormat, LoaderDataFrameInput from flash.core.data.io.input_transform import InputTransform from flash.core.data.process import Deserializer -from flash.core.integrations.labelstudio.data_source import LabelStudioImageClassificationDataSource +from flash.core.integrations.labelstudio.input import LabelStudioImageClassificationInput from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, Image, requires from 
flash.core.utilities.stages import RunningStage from flash.image.classification.transforms import default_transforms, train_default_transforms from flash.image.data import ( image_loader, ImageDeserializer, - ImageFiftyOneDataSource, - ImageNumpyDataSource, - ImagePathsDataSource, - ImageTensorDataSource, + ImageFiftyOneInput, + ImageNumpyInput, + ImagePathsInput, + ImageTensorInput, ) if _MATPLOTLIB_AVAILABLE: @@ -43,15 +43,15 @@ plt = None -class ImageClassificationDataFrameDataSource(LoaderDataFrameDataSource): +class ImageClassificationDataFrameInput(LoaderDataFrameInput): def __init__(self): super().__init__(image_loader) @requires("image") def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: sample = super().load_sample(sample, dataset) - w, h = sample[DefaultDataKeys.INPUT].size # WxH - sample[DefaultDataKeys.METADATA]["size"] = (h, w) + w, h = sample[DataKeys.INPUT].size # WxH + sample[DataKeys.METADATA]["size"] = (h, w) return sample @@ -65,7 +65,7 @@ class ImageClassificationInputTransform(InputTransform): predict_transform: image_size: tuple with the (heigh, width) of the images deserializer: - data_source_kwargs: Additional kwargs for the data source initializer + input_kwargs: Additional kwargs for the data source initializer """ def __init__( @@ -76,7 +76,7 @@ def __init__( predict_transform: Optional[Dict[str, Callable]] = None, image_size: Tuple[int, int] = (196, 196), deserializer: Optional[Deserializer] = None, - **data_source_kwargs: Any, + **input_kwargs: Any, ): self.image_size = image_size @@ -85,18 +85,18 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.FIFTYONE: ImageFiftyOneDataSource(**data_source_kwargs), - DefaultDataSources.FILES: ImagePathsDataSource(), - DefaultDataSources.FOLDERS: ImagePathsDataSource(), - DefaultDataSources.NUMPY: ImageNumpyDataSource(), - DefaultDataSources.TENSORS: 
ImageTensorDataSource(), - "data_frame": ImageClassificationDataFrameDataSource(), - DefaultDataSources.CSV: ImageClassificationDataFrameDataSource(), - DefaultDataSources.LABELSTUDIO: LabelStudioImageClassificationDataSource(), + inputs={ + InputFormat.FIFTYONE: ImageFiftyOneInput(**input_kwargs), + InputFormat.FILES: ImagePathsInput(), + InputFormat.FOLDERS: ImagePathsInput(), + InputFormat.NUMPY: ImageNumpyInput(), + InputFormat.TENSORS: ImageTensorInput(), + "data_frame": ImageClassificationDataFrameInput(), + InputFormat.CSV: ImageClassificationDataFrameInput(), + InputFormat.LABELSTUDIO: LabelStudioImageClassificationInput(), }, deserializer=deserializer or ImageDeserializer(), - default_data_source=DefaultDataSources.FILES, + default_input=InputFormat.FILES, ) def get_state_dict(self) -> Dict[str, Any]: @@ -196,7 +196,7 @@ def from_data_frame( Returns: The constructed data module. """ - return cls.from_data_source( + return cls.from_input( "data_frame", (train_data_frame, input_field, target_fields, train_images_root, train_resolver), (val_data_frame, input_field, target_fields, val_images_root, val_resolver), @@ -245,8 +245,8 @@ def from_csv( **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.image.classification.data.ImageClassificationData` object from the given CSV - files using the :class:`~flash.core.data.data_source.DataSource` of name - :attr:`~flash.core.data.data_source.DefaultDataSources.CSV` from the passed or constructed + files using the :class:`~flash.core.data.io.input.Input` of name + :attr:`~flash.core.data.io.input.InputFormat.CSV` from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: @@ -295,8 +295,8 @@ def from_csv( Returns: The constructed data module. 
""" - return cls.from_data_source( - DefaultDataSources.CSV, + return cls.from_input( + InputFormat.CSV, (train_file, input_field, target_fields, train_images_root, train_resolver), (val_file, input_field, target_fields, val_images_root, val_resolver), (test_file, input_field, target_fields, test_images_root, test_resolver), @@ -359,9 +359,9 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) for i, ax in enumerate(axs): # unpack images and labels if isinstance(data, list): - _img, _label = data[i][DefaultDataKeys.INPUT], data[i].get(DefaultDataKeys.TARGET, "") + _img, _label = data[i][DataKeys.INPUT], data[i].get(DataKeys.TARGET, "") elif isinstance(data, dict): - _img, _label = data[DefaultDataKeys.INPUT][i], data.get(DefaultDataKeys.TARGET, [""] * (i + 1))[i] + _img, _label = data[DataKeys.INPUT][i], data.get(DataKeys.TARGET, [""] * (i + 1))[i] else: raise TypeError(f"Unknown data type. Got: {type(data)}.") # convert images to numpy @@ -392,4 +392,4 @@ def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningS def show_per_batch_transform(self, batch: List[Any], running_stage): win_title: str = f"{running_stage} - show_per_batch_transform" - self._show_images_and_labels(batch[0], batch[0][DefaultDataKeys.INPUT].shape[0], win_title) + self._show_images_and_labels(batch[0], batch[0][DataKeys.INPUT].shape[0], win_title) diff --git a/flash/image/classification/transforms.py b/flash/image/classification/transforms.py index 738823a56e..36e91b4c19 100644 --- a/flash/image/classification/transforms.py +++ b/flash/image/classification/transforms.py @@ -17,7 +17,7 @@ import torch from torch import nn -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms from flash.core.utilities.imports import _ALBUMENTATIONS_AVAILABLE, _KORNIA_AVAILABLE, _TORCHVISION_AVAILABLE, requires @@ -51,27 
+51,27 @@ def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: # Better approach as all transforms are applied on tensor directly return { "to_tensor_transform": nn.Sequential( - ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), - ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ), "post_tensor_transform": ApplyToKeys( - DefaultDataKeys.INPUT, + DataKeys.INPUT, K.geometry.Resize(image_size), ), "collate": kornia_collate, "per_batch_transform_on_device": ApplyToKeys( - DefaultDataKeys.INPUT, + DataKeys.INPUT, K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), ), } return { - "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(image_size)), + "pre_tensor_transform": ApplyToKeys(DataKeys.INPUT, T.Resize(image_size)), "to_tensor_transform": nn.Sequential( - ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), - ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ), "post_tensor_transform": ApplyToKeys( - DefaultDataKeys.INPUT, + DataKeys.INPUT, T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ), "collate": kornia_collate, @@ -83,8 +83,8 @@ def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable] if _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1": # Better approach as all transforms are applied on tensor directly transforms = { - "post_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, K.augmentation.RandomHorizontalFlip()), + "post_tensor_transform": ApplyToKeys(DataKeys.INPUT, K.augmentation.RandomHorizontalFlip()), } else: - transforms = {"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.RandomHorizontalFlip())} + transforms = 
{"pre_tensor_transform": ApplyToKeys(DataKeys.INPUT, T.RandomHorizontalFlip())} return merge_transforms(default_transforms(image_size), transforms) diff --git a/flash/image/data.py b/flash/image/data.py index d8fa784ce0..3d098e4c17 100644 --- a/flash/image/data.py +++ b/flash/image/data.py @@ -21,13 +21,13 @@ import torch import flash -from flash.core.data.data_source import ( - DefaultDataKeys, - FiftyOneDataSource, +from flash.core.data.io.input import ( + DataKeys, + FiftyOneInput, has_file_allowed_extension, - NumpyDataSource, - PathsDataSource, - TensorDataSource, + NumpyInput, + PathsInput, + TensorInput, ) from flash.core.data.process import Deserializer from flash.core.data.utils import image_default_loader @@ -64,7 +64,7 @@ def deserialize(self, data: str) -> Dict: buffer = BytesIO(img) img = Image.open(buffer, mode="r") return { - DefaultDataKeys.INPUT: img, + DataKeys.INPUT: img, } @property @@ -76,51 +76,51 @@ def example_input(self) -> str: def _labels_to_indices(data): out = defaultdict(list) for idx, sample in enumerate(data): - label = sample[DefaultDataKeys.TARGET] + label = sample[DataKeys.TARGET] if torch.is_tensor(label): label = label.item() out[label].append(idx) return out -class ImagePathsDataSource(PathsDataSource): +class ImagePathsInput(PathsInput): def __init__(self): super().__init__(loader=image_loader, extensions=IMG_EXTENSIONS + NP_EXTENSIONS) @requires("image") def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: sample = super().load_sample(sample, dataset) - w, h = sample[DefaultDataKeys.INPUT].size # WxH - sample[DefaultDataKeys.METADATA]["size"] = (h, w) + w, h = sample[DataKeys.INPUT].size # WxH + sample[DataKeys.METADATA]["size"] = (h, w) return sample -class ImageTensorDataSource(TensorDataSource): +class ImageTensorInput(TensorInput): def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - img = to_pil_image(sample[DefaultDataKeys.INPUT]) - 
sample[DefaultDataKeys.INPUT] = img + img = to_pil_image(sample[DataKeys.INPUT]) + sample[DataKeys.INPUT] = img w, h = img.size # WxH - sample[DefaultDataKeys.METADATA] = {"size": (h, w)} + sample[DataKeys.METADATA] = {"size": (h, w)} return sample -class ImageNumpyDataSource(NumpyDataSource): +class ImageNumpyInput(NumpyInput): def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - img = to_pil_image(torch.from_numpy(sample[DefaultDataKeys.INPUT])) - sample[DefaultDataKeys.INPUT] = img + img = to_pil_image(torch.from_numpy(sample[DataKeys.INPUT])) + sample[DataKeys.INPUT] = img w, h = img.size # WxH - sample[DefaultDataKeys.METADATA] = {"size": (h, w)} + sample[DataKeys.METADATA] = {"size": (h, w)} return sample -class ImageFiftyOneDataSource(FiftyOneDataSource): +class ImageFiftyOneInput(FiftyOneInput): @staticmethod def load_sample(sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - img_path = sample[DefaultDataKeys.INPUT] + img_path = sample[DataKeys.INPUT] img = image_default_loader(img_path) - sample[DefaultDataKeys.INPUT] = img + sample[DataKeys.INPUT] = img w, h = img.size # WxH - sample[DefaultDataKeys.METADATA] = { + sample[DataKeys.METADATA] = { "filepath": img_path, "size": (h, w), } diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index 9d57c35233..b117c6611c 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -15,9 +15,9 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, FiftyOneDataSource +from flash.core.data.io.input import DataKeys, FiftyOneInput, InputFormat from flash.core.data.io.input_transform import InputTransform -from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource +from flash.core.integrations.icevision.data import IceVisionParserInput, 
IceVisionPathsInput from flash.core.integrations.icevision.transforms import default_transforms from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, lazy_import, requires @@ -103,7 +103,7 @@ def _reformat_bbox(xmin, ymin, box_w, box_h, img_w, img_h): return output_bbox -class ObjectDetectionFiftyOneDataSource(IceVisionPathsDataSource, FiftyOneDataSource): +class ObjectDetectionFiftyOneInput(IceVisionPathsInput, FiftyOneInput): def __init__(self, label_field: str = "ground_truth", iscrowd: str = "iscrowd"): super().__init__() self.label_field = label_field @@ -125,12 +125,12 @@ def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Se parser = FiftyOneParser(data, class_map, self.label_field, self.iscrowd) records = parser.parse(data_splitter=SingleSplitSplitter()) - return [{DefaultDataKeys.INPUT: record} for record in records[0]] + return [{DataKeys.INPUT: record} for record in records[0]] @staticmethod @requires("fiftyone") def predict_load_data(data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: - return [{DefaultDataKeys.INPUT: f} for f in data.values("filepath")] + return [{DataKeys.INPUT: f} for f in data.values("filepath")] class ObjectDetectionInputTransform(InputTransform): @@ -142,7 +142,7 @@ def __init__( predict_transform: Optional[Dict[str, Callable]] = None, image_size: Tuple[int, int] = (128, 128), parser: Optional[Callable] = None, - **data_source_kwargs: Any, + **_kwargs: Any, ): self.image_size = image_size @@ -151,15 +151,15 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - "coco": IceVisionParserDataSource(parser=COCOBBoxParser), - "via": IceVisionParserDataSource(parser=VIABBoxParser), - "voc": IceVisionParserDataSource(parser=VOCBBoxParser), - DefaultDataSources.FILES: IceVisionPathsDataSource(), - DefaultDataSources.FOLDERS: IceVisionParserDataSource(parser=parser), - 
DefaultDataSources.FIFTYONE: ObjectDetectionFiftyOneDataSource(**data_source_kwargs), + inputs={ + "coco": IceVisionParserInput(parser=COCOBBoxParser), + "via": IceVisionParserInput(parser=VIABBoxParser), + "voc": IceVisionParserInput(parser=VOCBBoxParser), + InputFormat.FILES: IceVisionPathsInput(), + InputFormat.FOLDERS: IceVisionParserInput(parser=parser), + InputFormat.FIFTYONE: ObjectDetectionFiftyOneInput(**_kwargs), }, - default_data_source=DefaultDataSources.FILES, + default_input=InputFormat.FILES, ) self._default_collate = self._identity @@ -243,7 +243,7 @@ def from_coco( train_ann_file="annotations.json", ) """ - return cls.from_data_source( + return cls.from_input( "coco", (train_folder, train_ann_file) if train_folder else None, (val_folder, val_ann_file) if val_folder else None, @@ -322,7 +322,7 @@ def from_voc( train_ann_file="annotations.json", ) """ - return cls.from_data_source( + return cls.from_input( "voc", (train_folder, train_ann_file) if train_folder else None, (val_folder, val_ann_file) if val_folder else None, @@ -401,7 +401,7 @@ def from_via( train_ann_file="annotations.json", ) """ - return cls.from_data_source( + return cls.from_input( "via", (train_folder, train_ann_file) if train_folder else None, (val_folder, val_ann_file) if val_folder else None, diff --git a/flash/image/detection/output.py b/flash/image/detection/output.py index 1b56c734d2..b52c24cbe4 100644 --- a/flash/image/detection/output.py +++ b/flash/image/detection/output.py @@ -15,7 +15,7 @@ from pytorch_lightning.utilities import rank_zero_warn -from flash.core.data.data_source import DefaultDataKeys, LabelsState +from flash.core.data.io.input import DataKeys, LabelsState from flash.core.data.io.output import Output from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires @@ -56,7 +56,7 @@ def __init__( self.set_state(LabelsState(labels)) def transform(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]]: - if 
DefaultDataKeys.METADATA not in sample: + if DataKeys.METADATA not in sample: raise ValueError("sample requires DefaultDataKeys.METADATA to use a FiftyOneDetectionLabels output.") labels = None @@ -69,11 +69,11 @@ def transform(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]] else: rank_zero_warn("No LabelsState was found, int targets will be used as label strings", UserWarning) - height, width = sample[DefaultDataKeys.METADATA]["size"] + height, width = sample[DataKeys.METADATA]["size"] detections = [] - preds = sample[DefaultDataKeys.PREDS] + preds = sample[DataKeys.PREDS] for bbox, label, score in zip(preds["bboxes"], preds["labels"], preds["scores"]): confidence = score.tolist() @@ -104,6 +104,6 @@ def transform(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]] ) fo_predictions = fo.Detections(detections=detections) if self.return_filepath: - filepath = sample[DefaultDataKeys.METADATA]["filepath"] + filepath = sample[DataKeys.METADATA]["filepath"] return {"filepath": filepath, "predictions": fo_predictions} return fo_predictions diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index c5eda2041a..1c470e3760 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -15,7 +15,7 @@ from typing import Any, Dict, List, Optional from flash.core.adapter import AdapterTask -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.states import ( CollateFn, PerBatchTransform, @@ -121,7 +121,7 @@ def __init__( transform, collate_fn = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs) to_tensor_transform = ApplyToKeys( - DefaultDataKeys.INPUT, + DataKeys.INPUT, transform, ) diff --git a/flash/image/embedding/vissl/adapter.py b/flash/image/embedding/vissl/adapter.py index 31a880a572..af48c4d2f2 100644 --- a/flash/image/embedding/vissl/adapter.py +++ b/flash/image/embedding/vissl/adapter.py @@ -17,7 
+17,7 @@ import torch.nn as nn from flash.core.adapter import Adapter -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.model import Task from flash.core.utilities.imports import _VISSL_AVAILABLE from flash.image.embedding.vissl.hooks import AdaptVISSLHooks @@ -169,10 +169,10 @@ def ssl_forward(self, batch) -> Any: return model_output def shared_step(self, batch: Any, train: bool = True) -> Any: - out = self.ssl_forward(batch[DefaultDataKeys.INPUT]) + out = self.ssl_forward(batch[DataKeys.INPUT]) # for moco and dino - self.task.last_batch["sample"]["input"] = batch[DefaultDataKeys.INPUT] + self.task.last_batch["sample"]["input"] = batch[DataKeys.INPUT] if "data_momentum" in batch.keys(): self.task.last_batch["sample"]["data_momentum"] = [batch["data_momentum"]] @@ -204,6 +204,6 @@ def test_step(self, batch: Any, batch_idx: int) -> None: return loss def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - input_image = batch[DefaultDataKeys.INPUT] + input_image = batch[DataKeys.INPUT] return self(input_image) diff --git a/flash/image/embedding/vissl/transforms/utilities.py b/flash/image/embedding/vissl/transforms/utilities.py index 7909cbdda2..079015f639 100644 --- a/flash/image/embedding/vissl/transforms/utilities.py +++ b/flash/image/embedding/vissl/transforms/utilities.py @@ -13,7 +13,7 @@ # limitations under the License. import torch -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys def vissl_collate_helper(samples): @@ -22,7 +22,7 @@ def vissl_collate_helper(samples): for batch_ele in samples: _batch_ele_dict = {} _batch_ele_dict.update(batch_ele) - _batch_ele_dict[DefaultDataKeys.INPUT] = -1 + _batch_ele_dict[DataKeys.INPUT] = -1 result.append(_batch_ele_dict) @@ -32,13 +32,13 @@ def vissl_collate_helper(samples): def multicrop_collate_fn(samples): """Multi-crop collate function for VISSL integration. 
- Run custom collate on a single key since VISSL transforms affect only DefaultDataKeys.INPUT + Run custom collate on a single key since VISSL transforms affect only DataKeys.INPUT """ result = vissl_collate_helper(samples) - inputs = [[] for _ in range(len(samples[0][DefaultDataKeys.INPUT]))] + inputs = [[] for _ in range(len(samples[0][DataKeys.INPUT]))] for batch_ele in samples: - multi_crop_imgs = batch_ele[DefaultDataKeys.INPUT] + multi_crop_imgs = batch_ele[DataKeys.INPUT] for idx, crop in enumerate(multi_crop_imgs): inputs[idx].append(crop) @@ -46,7 +46,7 @@ def multicrop_collate_fn(samples): for idx, ele in enumerate(inputs): inputs[idx] = torch.stack(ele) - result[DefaultDataKeys.INPUT] = inputs + result[DataKeys.INPUT] = inputs return result @@ -54,21 +54,21 @@ def multicrop_collate_fn(samples): def simclr_collate_fn(samples): """Multi-crop collate function for VISSL integration. - Run custom collate on a single key since VISSL transforms affect only DefaultDataKeys.INPUT + Run custom collate on a single key since VISSL transforms affect only DataKeys.INPUT """ result = vissl_collate_helper(samples) inputs = [] - num_views = len(samples[0][DefaultDataKeys.INPUT]) + num_views = len(samples[0][DataKeys.INPUT]) view_idx = 0 while view_idx < num_views: for batch_ele in samples: - imgs = batch_ele[DefaultDataKeys.INPUT] + imgs = batch_ele[DataKeys.INPUT] inputs.append(imgs[view_idx]) view_idx += 1 - result[DefaultDataKeys.INPUT] = torch.stack(inputs) + result[DataKeys.INPUT] = torch.stack(inputs) return result @@ -76,15 +76,15 @@ def simclr_collate_fn(samples): def moco_collate_fn(samples): """MOCO collate function for VISSL integration. 
- Run custom collate on a single key since VISSL transforms affect only DefaultDataKeys.INPUT + Run custom collate on a single key since VISSL transforms affect only DataKeys.INPUT """ result = vissl_collate_helper(samples) inputs = [] for batch_ele in samples: - inputs.append(torch.stack(batch_ele[DefaultDataKeys.INPUT])) + inputs.append(torch.stack(batch_ele[DataKeys.INPUT])) - result[DefaultDataKeys.INPUT] = torch.stack(inputs).squeeze()[:, 0, :, :, :].squeeze() + result[DataKeys.INPUT] = torch.stack(inputs).squeeze()[:, 0, :, :, :].squeeze() result["data_momentum"] = torch.stack(inputs).squeeze()[:, 1, :, :, :].squeeze() return result diff --git a/flash/image/face_detection/data.py b/flash/image/face_detection/data.py index c78a03899d..a241c38ea9 100644 --- a/flash/image/face_detection/data.py +++ b/flash/image/face_detection/data.py @@ -17,13 +17,13 @@ import torch.nn as nn from torch.utils.data import Dataset -from flash.core.data.data_source import DatasetDataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.io.input import DataKeys, DatasetInput, InputFormat from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform from flash.core.data.transforms import ApplyToKeys from flash.core.data.utils import image_default_loader from flash.core.utilities.imports import _FASTFACE_AVAILABLE, _TORCHVISION_AVAILABLE -from flash.image.data import ImagePathsDataSource +from flash.image.data import ImagePathsInput from flash.image.detection import ObjectDetectionData if _TORCHVISION_AVAILABLE: @@ -40,15 +40,13 @@ def fastface_collate_fn(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence """ samples = {key: [sample[key] for sample in samples] for key in samples[0]} - images, scales, paddings = ff.utils.preprocess.prepare_batch( - samples[DefaultDataKeys.INPUT], None, adaptive_batch=True - ) + images, scales, paddings = ff.utils.preprocess.prepare_batch(samples[DataKeys.INPUT], None, 
adaptive_batch=True) samples["scales"] = scales samples["paddings"] = paddings - if DefaultDataKeys.TARGET in samples.keys(): - targets = samples[DefaultDataKeys.TARGET] + if DataKeys.TARGET in samples.keys(): + targets = samples[DataKeys.TARGET] targets = [{"target_boxes": target["boxes"]} for target in targets] for i, (target, scale, padding) in enumerate(zip(targets, scales, paddings)): @@ -57,13 +55,13 @@ def fastface_collate_fn(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence target["target_boxes"][:, [1, 3]] += padding[1] targets[i]["target_boxes"] = target["target_boxes"] - samples[DefaultDataKeys.TARGET] = targets - samples[DefaultDataKeys.INPUT] = images + samples[DataKeys.TARGET] = targets + samples[DataKeys.INPUT] = images return samples -class FastFaceDataSource(DatasetDataSource): +class FastFaceInput(DatasetInput): """Logic for loading from FDDBDataset.""" def load_data(self, data: Dataset, dataset: Any = None) -> Dataset: @@ -85,12 +83,12 @@ def load_data(self, data: Dataset, dataset: Any = None) -> Dataset: return new_data def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]: - filepath = sample[DefaultDataKeys.INPUT] + filepath = sample[DataKeys.INPUT] img = image_default_loader(filepath) - sample[DefaultDataKeys.INPUT] = img + sample[DataKeys.INPUT] = img w, h = img.size # WxH - sample[DefaultDataKeys.METADATA] = { + sample[DataKeys.METADATA] = { "filepath": filepath, "size": (h, w), } @@ -116,12 +114,12 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.FILES: ImagePathsDataSource(), - DefaultDataSources.FOLDERS: ImagePathsDataSource(), - DefaultDataSources.DATASETS: FastFaceDataSource(), + inputs={ + InputFormat.FILES: ImagePathsInput(), + InputFormat.FOLDERS: ImagePathsInput(), + InputFormat.DATASETS: FastFaceInput(), }, - default_data_source=DefaultDataSources.FILES, + default_input=InputFormat.FILES, ) 
def get_state_dict(self) -> Dict[str, Any]: @@ -134,9 +132,9 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): def default_transforms(self) -> Dict[str, Callable]: return { "to_tensor_transform": nn.Sequential( - ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), ApplyToKeys( - DefaultDataKeys.TARGET, + DataKeys.TARGET, nn.Sequential( ApplyToKeys("boxes", torch.as_tensor), ApplyToKeys("labels", torch.as_tensor), @@ -158,12 +156,12 @@ def per_batch_transform(batch: Any) -> Any: batch.pop("scales", None) batch.pop("paddings", None) - preds = batch[DefaultDataKeys.PREDS] + preds = batch[DataKeys.PREDS] # preds: list of torch.Tensor(N, 5) as x1, y1, x2, y2, score preds = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(len(preds))] preds = ff.utils.preprocess.adjust_results(preds, scales, paddings) - batch[DefaultDataKeys.PREDS] = preds + batch[DataKeys.PREDS] = preds return batch diff --git a/flash/image/face_detection/model.py b/flash/image/face_detection/model.py index a4beb464d4..90267f8d38 100644 --- a/flash/image/face_detection/model.py +++ b/flash/image/face_detection/model.py @@ -16,7 +16,7 @@ import pytorch_lightning as pl import torch -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.io.output import Output from flash.core.finetuning import FlashBaseFinetuning from flash.core.model import Task @@ -48,7 +48,7 @@ class DetectionLabels(Output): """A :class:`.Output` which extracts predictions from sample dict.""" def transform(self, sample: Any) -> Dict[str, Any]: - return sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample + return sample[DataKeys.PREDS] if isinstance(sample, Dict) else sample class FaceDetector(Task): @@ -154,7 +154,7 @@ def _compute_metrics(self, logits, targets): metric.update(pred_boxes, target_boxes) def __shared_step(self, 
batch, train=False) -> Any: - images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] + images, targets = batch[DataKeys.INPUT], batch[DataKeys.TARGET] images = self._prepare_batch(images) logits = self.model(images) loss = self.model.compute_loss(logits, targets) @@ -190,8 +190,8 @@ def test_epoch_end(self, outputs) -> None: self.log_dict({f"test_{k}": v for k, v in metric_results.items()}, on_epoch=True) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - images = batch[DefaultDataKeys.INPUT] - batch[DefaultDataKeys.PREDS] = self(images) + images = batch[DataKeys.INPUT] + batch[DataKeys.PREDS] = self(images) return batch def configure_finetune_callback(self): diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index 52ff2c197e..9586b7220d 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -15,10 +15,10 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources +from flash.core.data.io.input import DataKeys, InputFormat from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform -from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource +from flash.core.integrations.icevision.data import IceVisionParserInput, IceVisionPathsInput from flash.core.integrations.icevision.transforms import default_transforms from flash.core.utilities.imports import _ICEVISION_AVAILABLE @@ -46,13 +46,13 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - "coco": IceVisionParserDataSource(parser=COCOMaskParser), - "voc": IceVisionParserDataSource(parser=VOCMaskParser), - DefaultDataSources.FILES: 
IceVisionPathsDataSource(), - DefaultDataSources.FOLDERS: IceVisionParserDataSource(parser=parser), + inputs={ + "coco": IceVisionParserInput(parser=COCOMaskParser), + "voc": IceVisionParserInput(parser=VOCMaskParser), + InputFormat.FILES: IceVisionPathsInput(), + InputFormat.FOLDERS: IceVisionParserInput(parser=parser), }, - default_data_source=DefaultDataSources.FILES, + default_input=InputFormat.FILES, ) self._default_collate = self._identity @@ -74,7 +74,7 @@ def train_default_transforms(self) -> Optional[Dict[str, Callable]]: class InstanceSegmentationOutputTransform(OutputTransform): @staticmethod def uncollate(batch: Any) -> Any: - return batch[DefaultDataKeys.PREDS] + return batch[DataKeys.PREDS] class InstanceSegmentationData(DataModule): @@ -143,7 +143,7 @@ def from_coco( train_ann_file="annotations.json", ) """ - return cls.from_data_source( + return cls.from_input( "coco", (train_folder, train_ann_file) if train_folder else None, (val_folder, val_ann_file) if val_folder else None, @@ -222,7 +222,7 @@ def from_voc( train_ann_file="annotations.json", ) """ - return cls.from_data_source( + return cls.from_input( "voc", (train_folder, train_ann_file) if train_folder else None, (val_folder, val_ann_file) if val_folder else None, diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py index 3885c6cc04..4eecaa19b2 100644 --- a/flash/image/keypoint_detection/data.py +++ b/flash/image/keypoint_detection/data.py @@ -15,9 +15,9 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DefaultDataSources +from flash.core.data.io.input import InputFormat from flash.core.data.io.input_transform import InputTransform -from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource +from flash.core.integrations.icevision.data import IceVisionParserInput, IceVisionPathsInput from 
flash.core.integrations.icevision.transforms import default_transforms from flash.core.utilities.imports import _ICEVISION_AVAILABLE @@ -44,12 +44,12 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - "coco": IceVisionParserDataSource(parser=COCOKeyPointsParser), - DefaultDataSources.FILES: IceVisionPathsDataSource(), - DefaultDataSources.FOLDERS: IceVisionParserDataSource(parser=parser), + inputs={ + "coco": IceVisionParserInput(parser=COCOKeyPointsParser), + InputFormat.FILES: IceVisionPathsInput(), + InputFormat.FOLDERS: IceVisionParserInput(parser=parser), }, - default_data_source=DefaultDataSources.FILES, + default_input=InputFormat.FILES, ) self._default_collate = self._identity @@ -133,7 +133,7 @@ def from_coco( train_ann_file="annotations.json", ) """ - return cls.from_data_source( + return cls.from_input( "coco", (train_folder, train_ann_file) if train_folder else None, (val_folder, val_ann_file) if val_folder else None, diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 920986ce80..69eb67b783 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -24,14 +24,14 @@ from flash.core.data.base_viz import BaseVisualization # for viz from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import ( - DefaultDataKeys, - DefaultDataSources, - FiftyOneDataSource, +from flash.core.data.io.input import ( + DataKeys, + FiftyOneInput, ImageLabelsMap, - NumpyDataSource, - PathsDataSource, - TensorDataSource, + InputFormat, + NumpyInput, + PathsInput, + TensorInput, ) from flash.core.data.io.input_transform import InputTransform from flash.core.data.process import Deserializer @@ -68,23 +68,23 @@ from torchvision.datasets.folder import has_file_allowed_extension -class SemanticSegmentationNumpyDataSource(NumpyDataSource): +class 
SemanticSegmentationNumpyInput(NumpyInput): def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - img = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float() - sample[DefaultDataKeys.INPUT] = img - sample[DefaultDataKeys.METADATA] = {"size": img.shape} + img = torch.from_numpy(sample[DataKeys.INPUT]).float() + sample[DataKeys.INPUT] = img + sample[DataKeys.METADATA] = {"size": img.shape} return sample -class SemanticSegmentationTensorDataSource(TensorDataSource): +class SemanticSegmentationTensorInput(TensorInput): def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - img = sample[DefaultDataKeys.INPUT].float() - sample[DefaultDataKeys.INPUT] = img - sample[DefaultDataKeys.METADATA] = {"size": img.shape} + img = sample[DataKeys.INPUT].float() + sample[DataKeys.INPUT] = img + sample[DataKeys.METADATA] = {"size": img.shape} return sample -class SemanticSegmentationPathsDataSource(PathsDataSource): +class SemanticSegmentationPathsInput(PathsInput): def __init__(self): super().__init__(IMG_EXTENSIONS) @@ -127,7 +127,7 @@ def load_data( zip(input_data, target_data), ) - data = [{DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in data] + data = [{DataKeys.INPUT: input, DataKeys.TARGET: target} for input, target in data] return data @@ -136,17 +136,17 @@ def predict_load_data(self, data: Union[str, List[str]]): def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Tensor, torch.Size]]: # unpack data paths - img_path = sample[DefaultDataKeys.INPUT] - img_labels_path = sample[DefaultDataKeys.TARGET] + img_path = sample[DataKeys.INPUT] + img_labels_path = sample[DataKeys.TARGET] # load images directly to torch tensors img: torch.Tensor = FT.to_tensor(image_default_loader(img_path)) # CxHxW img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW img_labels = img_labels[0] # HxW - sample[DefaultDataKeys.INPUT] = 
img.float() - sample[DefaultDataKeys.TARGET] = img_labels.float() - sample[DefaultDataKeys.METADATA] = { + sample[DataKeys.INPUT] = img.float() + sample[DataKeys.TARGET] = img_labels.float() + sample[DataKeys.METADATA] = { "filepath": img_path, "size": img.shape, } @@ -154,18 +154,18 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten @staticmethod def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: - img_path = sample[DefaultDataKeys.INPUT] + img_path = sample[DataKeys.INPUT] img = FT.to_tensor(image_default_loader(img_path)).float() - sample[DefaultDataKeys.INPUT] = img - sample[DefaultDataKeys.METADATA] = { + sample[DataKeys.INPUT] = img + sample[DataKeys.METADATA] = { "filepath": img_path, "size": img.shape, } return sample -class SemanticSegmentationFiftyOneDataSource(FiftyOneDataSource): +class SemanticSegmentationFiftyOneInput(FiftyOneInput): def __init__(self, label_field: str = "ground_truth"): super().__init__(label_field=label_field) self._fo_dataset_name = None @@ -178,20 +178,20 @@ def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Se self._validate(data) self._fo_dataset_name = data.name - return [{DefaultDataKeys.INPUT: f} for f in data.values("filepath")] + return [{DataKeys.INPUT: f} for f in data.values("filepath")] def load_sample(self, sample: Mapping[str, str]) -> Mapping[str, Union[torch.Tensor, torch.Size]]: _fo_dataset = fo.load_dataset(self._fo_dataset_name) - img_path = sample[DefaultDataKeys.INPUT] + img_path = sample[DataKeys.INPUT] fo_sample = _fo_dataset[img_path] img: torch.Tensor = FT.to_tensor(image_default_loader(img_path)) # CxHxW img_labels: torch.Tensor = torch.from_numpy(fo_sample[self.label_field].mask) # HxW - sample[DefaultDataKeys.INPUT] = img.float() - sample[DefaultDataKeys.TARGET] = img_labels.float() - sample[DefaultDataKeys.METADATA] = { + sample[DataKeys.INPUT] = img.float() + sample[DataKeys.TARGET] = img_labels.float() + 
sample[DataKeys.METADATA] = { "filepath": img_path, "size": img.shape, } @@ -199,11 +199,11 @@ def load_sample(self, sample: Mapping[str, str]) -> Mapping[str, Union[torch.Ten @staticmethod def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: - img_path = sample[DefaultDataKeys.INPUT] + img_path = sample[DataKeys.INPUT] img = FT.to_tensor(image_default_loader(img_path)).float() - sample[DefaultDataKeys.INPUT] = img - sample[DefaultDataKeys.METADATA] = { + sample[DataKeys.INPUT] = img + sample[DataKeys.METADATA] = { "filepath": img_path, "size": img.shape, } @@ -213,8 +213,8 @@ def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: class SemanticSegmentationDeserializer(ImageDeserializer): def deserialize(self, data: str) -> torch.Tensor: result = super().deserialize(data) - result[DefaultDataKeys.INPUT] = FT.to_tensor(result[DefaultDataKeys.INPUT]) - result[DefaultDataKeys.METADATA] = {"size": result[DefaultDataKeys.INPUT].shape} + result[DataKeys.INPUT] = FT.to_tensor(result[DataKeys.INPUT]) + result[DataKeys.METADATA] = {"size": result[DataKeys.INPUT].shape} return result @@ -229,7 +229,7 @@ def __init__( deserializer: Optional["Deserializer"] = None, num_classes: int = None, labels_map: Dict[int, Tuple[int, int, int]] = None, - **data_source_kwargs: Any, + **input_kwargs: Any, ) -> None: """InputTransform pipeline for semantic segmentation tasks. @@ -239,7 +239,7 @@ def __init__( test_transform: Dictionary with the set of transforms to apply during testing. predict_transform: Dictionary with the set of transforms to apply during prediction. image_size: A tuple with the expected output image size. - **data_source_kwargs: Additional arguments passed on to the data source constructors. + **input_kwargs: Additional arguments passed on to the input constructors. 
""" self.image_size = image_size self.num_classes = num_classes @@ -251,15 +251,15 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.FIFTYONE: SemanticSegmentationFiftyOneDataSource(**data_source_kwargs), - DefaultDataSources.FILES: SemanticSegmentationPathsDataSource(), - DefaultDataSources.FOLDERS: SemanticSegmentationPathsDataSource(), - DefaultDataSources.TENSORS: SemanticSegmentationTensorDataSource(), - DefaultDataSources.NUMPY: SemanticSegmentationNumpyDataSource(), + inputs={ + InputFormat.FIFTYONE: SemanticSegmentationFiftyOneInput(**input_kwargs), + InputFormat.FILES: SemanticSegmentationPathsInput(), + InputFormat.FOLDERS: SemanticSegmentationPathsInput(), + InputFormat.TENSORS: SemanticSegmentationTensorInput(), + InputFormat.NUMPY: SemanticSegmentationNumpyInput(), }, deserializer=deserializer or SemanticSegmentationDeserializer(), - default_data_source=DefaultDataSources.FILES, + default_input=InputFormat.FILES, ) if labels_map: @@ -305,9 +305,9 @@ def set_block_viz_window(self, value: bool) -> None: self.data_fetcher.block_viz_window = value @classmethod - def from_data_source( + def from_input( cls, - data_source: str, + input: str, train_data: Any = None, val_data: Any = None, test_data: Any = None, @@ -338,8 +338,8 @@ def from_data_source( if flash._IS_TESTING: data_fetcher.block_viz_window = True - dm = super().from_data_source( - data_source=data_source, + dm = super().from_input( + input=input, train_data=train_data, val_data=val_data, test_data=test_data, @@ -428,8 +428,8 @@ def from_folders( train_target_folder="train_masks", ) """ - return cls.from_data_source( - DefaultDataSources.FOLDERS, + return cls.from_input( + InputFormat.FOLDERS, (train_folder, train_target_folder), (val_folder, val_target_folder), (test_folder, test_target_folder), @@ -485,8 +485,8 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) # 
unpack images and labels sample = data[i] if isinstance(sample, dict): - image = sample[DefaultDataKeys.INPUT] - label = sample[DefaultDataKeys.TARGET] + image = sample[DataKeys.INPUT] + label = sample[DataKeys.TARGET] elif isinstance(sample, tuple): image = sample[0] label = sample[1] diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index d00ee0e8d5..9296db60cb 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -19,7 +19,7 @@ from torchmetrics import IoU from flash.core.classification import ClassificationTask -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.io.output_transform import OutputTransform from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _KORNIA_AVAILABLE @@ -42,9 +42,9 @@ class SemanticSegmentationOutputTransform(OutputTransform): def per_sample_transform(self, sample: Any) -> Any: - resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA]["size"][-2:], interpolation="bilinear") - sample[DefaultDataKeys.PREDS] = resize(sample[DefaultDataKeys.PREDS]) - sample[DefaultDataKeys.INPUT] = resize(sample[DefaultDataKeys.INPUT]) + resize = K.geometry.Resize(sample[DataKeys.METADATA]["size"][-2:], interpolation="bilinear") + sample[DataKeys.PREDS] = resize(sample[DataKeys.PREDS]) + sample[DataKeys.INPUT] = resize(sample[DataKeys.INPUT]) return super().per_sample_transform(sample) @@ -137,20 +137,20 @@ def __init__( self.backbone = self.head.encoder def training_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return super().training_step(batch, batch_idx) def validation_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return 
super().validation_step(batch, batch_idx) def test_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch_input = batch[DefaultDataKeys.INPUT] - batch[DefaultDataKeys.PREDS] = super().predict_step(batch_input, batch_idx, dataloader_idx=dataloader_idx) + batch_input = batch[DataKeys.INPUT] + batch[DataKeys.PREDS] = super().predict_step(batch_input, batch_idx, dataloader_idx=dataloader_idx) return batch def forward(self, x) -> torch.Tensor: diff --git a/flash/image/segmentation/output.py b/flash/image/segmentation/output.py index e8873a421d..268cc89568 100644 --- a/flash/image/segmentation/output.py +++ b/flash/image/segmentation/output.py @@ -17,7 +17,7 @@ import torch import flash -from flash.core.data.data_source import DefaultDataKeys, ImageLabelsMap +from flash.core.data.io.input import DataKeys, ImageLabelsMap from flash.core.data.io.output import Output from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, @@ -91,7 +91,7 @@ def _visualize(self, labels): plt.show() def transform(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: - preds = sample[DefaultDataKeys.PREDS] + preds = sample[DataKeys.PREDS] assert len(preds.shape) == 3, preds.shape labels = torch.argmax(preds, dim=-3) # HxW @@ -126,6 +126,6 @@ def transform(self, sample: Dict[str, torch.Tensor]) -> Union[Segmentation, Dict labels = super().transform(sample) fo_predictions = fol.Segmentation(mask=np.array(labels)) if self.return_filepath: - filepath = sample[DefaultDataKeys.METADATA]["filepath"] + filepath = sample[DataKeys.METADATA]["filepath"] return {"filepath": filepath, "predictions": fo_predictions} return fo_predictions diff --git a/flash/image/segmentation/transforms.py b/flash/image/segmentation/transforms.py index 
8d2f301729..886f1e5c27 100644 --- a/flash/image/segmentation/transforms.py +++ b/flash/image/segmentation/transforms.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TORCHVISION_AVAILABLE if _KORNIA_AVAILABLE: @@ -39,11 +39,11 @@ def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: return { "post_tensor_transform": nn.Sequential( ApplyToKeys( - [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], + [DataKeys.INPUT, DataKeys.TARGET], KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation="nearest")), ), ), - "collate": Compose([kornia_collate, ApplyToKeys(DefaultDataKeys.TARGET, prepare_target)]), + "collate": Compose([kornia_collate, ApplyToKeys(DataKeys.TARGET, prepare_target)]), } @@ -55,7 +55,7 @@ def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable] { "post_tensor_transform": nn.Sequential( ApplyToKeys( - [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], + [DataKeys.INPUT, DataKeys.TARGET], KorniaParallelTransforms(K.augmentation.RandomHorizontalFlip(p=0.5)), ), ), @@ -64,11 +64,11 @@ def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable] def predict_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: - """During predict, we apply the default transforms only on DefaultDataKeys.INPUT.""" + """During predict, we apply the default transforms only on DataKeys.INPUT.""" return { "post_tensor_transform": nn.Sequential( ApplyToKeys( - DefaultDataKeys.INPUT, + DataKeys.INPUT, K.geometry.Resize(image_size, interpolation="nearest"), ), ), diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index 73a6621f31..4201844930 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -18,12 +18,12 @@ from torch import nn from 
flash.core.data.data_module import DataModule -from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources +from flash.core.data.io.input import DataKeys, InputFormat from flash.core.data.io.input_transform import InputTransform from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _TORCHVISION_AVAILABLE from flash.image.classification import ImageClassificationData -from flash.image.data import ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource +from flash.image.data import ImageNumpyInput, ImagePathsInput, ImageTensorInput from flash.image.style_transfer.utils import raise_not_supported if _TORCHVISION_AVAILABLE: @@ -33,7 +33,7 @@ def _apply_to_input( - default_transforms_fn, keys: Union[Sequence[DefaultDataKeys], DefaultDataKeys] + default_transforms_fn, keys: Union[Sequence[DataKeys], DataKeys] ) -> Callable[..., Dict[str, ApplyToKeys]]: @functools.wraps(default_transforms_fn) def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]: @@ -49,10 +49,10 @@ def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]: class StyleTransferInputTransform(InputTransform): def __init__( self, - train_transform: Optional[Union[Dict[str, Callable]]] = None, - val_transform: Optional[Union[Dict[str, Callable]]] = None, - test_transform: Optional[Union[Dict[str, Callable]]] = None, - predict_transform: Optional[Union[Dict[str, Callable]]] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, image_size: int = 256, ): if val_transform: @@ -70,14 +70,14 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.FILES: ImagePathsDataSource(), - DefaultDataSources.FOLDERS: ImagePathsDataSource(), - 
DefaultDataSources.NUMPY: ImageNumpyDataSource(), - DefaultDataSources.TENSORS: ImageTensorDataSource(), - DefaultDataSources.TENSORS: ImageTensorDataSource(), + inputs={ + InputFormat.FILES: ImagePathsInput(), + InputFormat.FOLDERS: ImagePathsInput(), + InputFormat.NUMPY: ImageNumpyInput(), + InputFormat.TENSORS: ImageTensorInput(), + InputFormat.TENSORS: ImageTensorInput(), }, - default_data_source=DefaultDataSources.FILES, + default_input=InputFormat.FILES, ) def get_state_dict(self) -> Dict[str, Any]: @@ -87,7 +87,7 @@ def get_state_dict(self) -> Dict[str, Any]: def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): return cls(**state_dict) - @functools.partial(_apply_to_input, keys=DefaultDataKeys.INPUT) + @functools.partial(_apply_to_input, keys=DataKeys.INPUT) def default_transforms(self) -> Optional[Dict[str, Callable]]: if self.training: return dict( @@ -131,8 +131,8 @@ def from_folders( predict_transform=predict_transform, ) - return cls.from_data_source( - DefaultDataSources.FOLDERS, + return cls.from_input( + InputFormat.FOLDERS, train_data=train_folder, predict_data=predict_folder, input_transform=input_transform, diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index 1ac19d005d..cebc3aec93 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -16,7 +16,7 @@ import torch from torch import nn -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _IMAGE_AVAILABLE @@ -153,7 +153,7 @@ def _get_perceptual_loss( return loss.PerceptualLoss(content_loss, style_loss) def training_step(self, batch: Any, batch_idx: int) -> Any: - input_image = batch[DefaultDataKeys.INPUT] + input_image = batch[DataKeys.INPUT] self.perceptual_loss.set_content_image(input_image) output_image = self(input_image) 
return self.perceptual_loss(output_image).total() @@ -165,5 +165,5 @@ def test_step(self, batch: Any, batch_idx: int) -> NoReturn: raise_not_supported("test") def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - input_image = batch[DefaultDataKeys.INPUT] + input_image = batch[DataKeys.INPUT] return self(input_image) diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py index 2e6b43f795..be27447ec7 100644 --- a/flash/pointcloud/detection/data.py +++ b/flash/pointcloud/detection/data.py @@ -4,16 +4,16 @@ from flash.core.data.base_viz import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import BaseDataFormat, DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.io.input import BaseDataFormat, DataKeys, Input, InputFormat from flash.core.data.io.input_transform import InputTransform from flash.core.data.process import Deserializer -from flash.pointcloud.detection.open3d_ml.data_sources import ( +from flash.pointcloud.detection.open3d_ml.inputs import ( PointCloudObjectDetectionDataFormat, - PointCloudObjectDetectorFoldersDataSource, + PointCloudObjectDetectorFoldersInput, ) -class PointCloudObjectDetectorDatasetDataSource(DataSource): +class PointCloudObjectDetectorDatasetInput(Input): def __init__(self, **kwargs): super().__init__() @@ -31,8 +31,8 @@ def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any: sample = dataset.dataset[index] return { - DefaultDataKeys.INPUT: sample["data"], - DefaultDataKeys.METADATA: sample["attr"], + DataKeys.INPUT: sample["data"], + DataKeys.METADATA: sample["attr"], } @@ -44,7 +44,7 @@ def __init__( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, deserializer: Optional[Deserializer] = None, - **data_source_kwargs, + **_kwargs, ): super().__init__( @@ -52,12 +52,12 @@ def __init__( val_transform=val_transform, 
test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.DATASETS: PointCloudObjectDetectorDatasetDataSource(**data_source_kwargs), - DefaultDataSources.FOLDERS: PointCloudObjectDetectorFoldersDataSource(**data_source_kwargs), + inputs={ + InputFormat.DATASETS: PointCloudObjectDetectorDatasetInput(**_kwargs), + InputFormat.FOLDERS: PointCloudObjectDetectorFoldersInput(**_kwargs), }, deserializer=deserializer, - default_data_source=DefaultDataSources.FOLDERS, + default_input=InputFormat.FOLDERS, ) def get_state_dict(self): @@ -99,8 +99,8 @@ def from_folders( **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the - :class:`~flash.core.data.data_source.DataSource` of name - :attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS` + :class:`~flash.core.data.io.input.Input` of name + :attr:`~flash.core.data.io.input.InputFormat.FOLDERS` from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. 
Args: @@ -144,8 +144,8 @@ def from_folders( }, ) """ - return cls.from_data_source( - DefaultDataSources.FOLDERS, + return cls.from_input( + InputFormat.FOLDERS, train_folder, val_folder, test_folder, diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index 7fb0500483..74600f9018 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -19,7 +19,7 @@ from torch.utils.data import DataLoader, Sampler from flash.core.data.auto_dataset import BaseAutoDataset -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.io.output import Output from flash.core.data.states import CollateFn from flash.core.model import Task @@ -132,9 +132,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A results = self.model(batch) boxes = self.model.inference_end(results, batch) return { - DefaultDataKeys.INPUT: getattr(batch, "point", None), - DefaultDataKeys.PREDS: boxes, - DefaultDataKeys.METADATA: [a["name"] for a in batch.attr], + DataKeys.INPUT: getattr(batch, "point", None), + DataKeys.PREDS: boxes, + DataKeys.METADATA: [a["name"] for a in batch.attr], } def forward(self, x) -> torch.Tensor: diff --git a/flash/pointcloud/detection/open3d_ml/app.py b/flash/pointcloud/detection/open3d_ml/app.py index d4bd99e289..dbaddcf4c1 100644 --- a/flash/pointcloud/detection/open3d_ml/app.py +++ b/flash/pointcloud/detection/open3d_ml/app.py @@ -16,7 +16,7 @@ import flash from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE if _POINTCLOUD_AVAILABLE: @@ -78,24 +78,22 @@ def on_done_ui(): # will not be loaded until now. 
self._update_bounding_boxes() - self._update_datasource_combobox() + self._update_datasource_combobox() self._update_shaders_combobox() # Display "colors" by default if available, "points" if not available_attrs = self._get_available_attrs() self._set_shader(self.SOLID_NAME, force_update=True) if "colors" in available_attrs: - self._datasource_combobox.selected_text = "colors" + self._datasource_combobox.selected_text = "colors" elif "points" in available_attrs: - self._datasource_combobox.selected_text = "points" + self._datasource_combobox.selected_text = "points" self._dont_update_geometry = True - self._on_datasource_changed( - self._datasource_combobox.selected_text, self._datasource_combobox.selected_index - ) + self._on_datasource_changed(self._datasource_combobox.selected_text, self._datasource_combobox.selected_index) self._update_geometry_colors() self._dont_update_geometry = False - # _datasource_combobox was empty, now isn't, re-layout. + # _datasource_combobox was empty, now isn't, re-layout. self.window.set_needs_layout() self._update_geometry() @@ -156,10 +154,10 @@ def show_predictions(self, predictions): for pred in predictions: data = { - "points": pred[DefaultDataKeys.INPUT][:, :3], - "name": pred[DefaultDataKeys.METADATA], + "points": pred[DataKeys.INPUT][:, :3], + "name": pred[DataKeys.METADATA], } - bounding_box = pred[DefaultDataKeys.PREDS] + bounding_box = pred[DataKeys.PREDS] viz.visualize([data], bounding_boxes=bounding_box) diff --git a/flash/pointcloud/detection/open3d_ml/data_sources.py b/flash/pointcloud/detection/open3d_ml/inputs.py similarity index 98% rename from flash/pointcloud/detection/open3d_ml/data_sources.py rename to flash/pointcloud/detection/open3d_ml/inputs.py index 6b945b0b2d..b4f7100145 100644 --- a/flash/pointcloud/detection/open3d_ml/data_sources.py +++ b/flash/pointcloud/detection/open3d_ml/inputs.py @@ -19,7 +19,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from flash.core.data.auto_dataset import BaseAutoDataset 
-from flash.core.data.data_source import BaseDataFormat, DataSource +from flash.core.data.io.input import BaseDataFormat, Input from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE if _POINTCLOUD_AVAILABLE: @@ -151,7 +151,7 @@ def predict_load_sample(self, data, dataset: Optional[BaseAutoDataset] = None): return data, attr -class PointCloudObjectDetectorFoldersDataSource(DataSource): +class PointCloudObjectDetectorFoldersInput(Input): def __init__( self, data_format: Optional[BaseDataFormat] = None, diff --git a/flash/pointcloud/segmentation/data.py b/flash/pointcloud/segmentation/data.py index 05293dffe3..04bbbb7df9 100644 --- a/flash/pointcloud/segmentation/data.py +++ b/flash/pointcloud/segmentation/data.py @@ -1,14 +1,14 @@ from typing import Any, Callable, Dict, Optional, Tuple from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.io.input import DataKeys, Input, InputFormat from flash.core.data.io.input_transform import InputTransform from flash.core.data.process import Deserializer from flash.core.utilities.imports import requires from flash.pointcloud.segmentation.open3d_ml.sequences_dataset import SequencesDataset -class PointCloudSegmentationDatasetDataSource(DataSource): +class PointCloudSegmentationDatasetInput(Input): def load_data( self, data: Any, @@ -25,12 +25,12 @@ def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any: sample = dataset.dataset[index] return { - DefaultDataKeys.INPUT: sample["data"], - DefaultDataKeys.METADATA: sample["attr"], + DataKeys.INPUT: sample["data"], + DataKeys.METADATA: sample["attr"], } -class PointCloudSegmentationFoldersDataSource(DataSource): +class PointCloudSegmentationFoldersInput(Input): @requires("pointcloud") def load_data( self, @@ -48,8 +48,8 @@ def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any: sample = dataset.dataset[index] return { - 
DefaultDataKeys.INPUT: sample["data"], - DefaultDataKeys.METADATA: sample["attr"], + DataKeys.INPUT: sample["data"], + DataKeys.METADATA: sample["attr"], } @@ -70,12 +70,12 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.DATASETS: PointCloudSegmentationDatasetDataSource(), - DefaultDataSources.FOLDERS: PointCloudSegmentationFoldersDataSource(), + inputs={ + InputFormat.DATASETS: PointCloudSegmentationDatasetInput(), + InputFormat.FOLDERS: PointCloudSegmentationFoldersInput(), }, deserializer=deserializer, - default_data_source=DefaultDataSources.FOLDERS, + default_input=InputFormat.FOLDERS, ) def get_state_dict(self): diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index a672f9f4a3..48028a0b50 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -23,7 +23,7 @@ from flash.core.classification import ClassificationTask from flash.core.data.auto_dataset import BaseAutoDataset -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.io.output import Output from flash.core.data.states import CollateFn from flash.core.finetuning import BaseFinetuning @@ -149,22 +149,22 @@ def to_loss_format(self, x: torch.Tensor) -> torch.Tensor: return x.reshape(-1, x.shape[-1]) def training_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.INPUT]["labels"].view(-1)) + batch = (batch[DataKeys.INPUT], batch[DataKeys.INPUT]["labels"].view(-1)) return super().training_step(batch, batch_idx) def validation_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.INPUT]["labels"].view(-1)) + batch = (batch[DataKeys.INPUT], batch[DataKeys.INPUT]["labels"].view(-1)) return super().validation_step(batch, batch_idx) def 
test_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.INPUT]["labels"].view(-1)) + batch = (batch[DataKeys.INPUT], batch[DataKeys.INPUT]["labels"].view(-1)) return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch[DefaultDataKeys.PREDS] = self(batch[DefaultDataKeys.INPUT]) - batch[DefaultDataKeys.TARGET] = batch[DefaultDataKeys.INPUT]["labels"] + batch[DataKeys.PREDS] = self(batch[DataKeys.INPUT]) + batch[DataKeys.TARGET] = batch[DataKeys.INPUT]["labels"] # drop sub-sampled pointclouds - batch[DefaultDataKeys.INPUT] = batch[DefaultDataKeys.INPUT]["xyz"][0] + batch[DataKeys.INPUT] = batch[DataKeys.INPUT]["xyz"][0] return batch def forward(self, x) -> torch.Tensor: diff --git a/flash/pointcloud/segmentation/open3d_ml/app.py b/flash/pointcloud/segmentation/open3d_ml/app.py index 45edb8bbe3..e9fbfd4c97 100644 --- a/flash/pointcloud/segmentation/open3d_ml/app.py +++ b/flash/pointcloud/segmentation/open3d_ml/app.py @@ -14,7 +14,7 @@ import torch from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE if _POINTCLOUD_AVAILABLE: @@ -86,10 +86,10 @@ def show_predictions(self, predictions): for pred in predictions: predictions_visualizations.append( { - "points": pred[DefaultDataKeys.INPUT], - "labels": pred[DefaultDataKeys.TARGET], - "predictions": torch.argmax(pred[DefaultDataKeys.PREDS], axis=-1) + 1, - "name": pred[DefaultDataKeys.METADATA]["name"], + "points": pred[DataKeys.INPUT], + "labels": pred[DataKeys.TARGET], + "predictions": torch.argmax(pred[DataKeys.PREDS], axis=-1) + 1, + "name": pred[DataKeys.METADATA]["name"], } ) diff --git a/flash/tabular/__init__.py b/flash/tabular/__init__.py index ae025f94b9..9dcc744e3e 100644 --- a/flash/tabular/__init__.py +++ 
b/flash/tabular/__init__.py @@ -2,7 +2,7 @@ from flash.tabular.data import TabularData # noqa: F401 from flash.tabular.forecasting.data import ( # noqa: F401 TabularForecastingData, - TabularForecastingDataFrameDataSource, + TabularForecastingDataFrameInput, TabularForecastingInputTransform, ) from flash.tabular.regression import TabularRegressionData, TabularRegressor # noqa: F401 diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 6c1ad46569..de6d2178c5 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -17,7 +17,7 @@ from torch.nn import functional as F from flash.core.classification import ClassificationTask, Probabilities -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _TABULAR_AVAILABLE from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE @@ -95,19 +95,19 @@ def forward(self, x_in) -> torch.Tensor: return self.model(x)[0] def training_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return super().training_step(batch, batch_idx) def validation_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return super().validation_step(batch, batch_idx) def test_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = batch[DefaultDataKeys.INPUT] + batch = batch[DataKeys.INPUT] return self(batch) @classmethod diff --git 
a/flash/tabular/data.py b/flash/tabular/data.py index b21e5e485c..c9d8fc75b6 100644 --- a/flash/tabular/data.py +++ b/flash/tabular/data.py @@ -21,7 +21,7 @@ from flash.core.classification import LabelsState from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.io.input import DataKeys, Input, InputFormat from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform from flash.core.data.process import Deserializer @@ -41,7 +41,7 @@ DataFrame = object -class TabularDataFrameDataSource(DataSource[DataFrame]): +class TabularDataFrameInput(Input[DataFrame]): def __init__( self, cat_cols: Optional[List[str]] = None, @@ -94,16 +94,14 @@ def common_load_data( def load_data(self, data: DataFrame, dataset: Optional[Any] = None): df, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) target = df[self.target_col].to_numpy().astype(np.float32 if self.is_regression else np.int64) - return [ - {DefaultDataKeys.INPUT: (c, n), DefaultDataKeys.TARGET: t} for c, n, t in zip(cat_vars, num_vars, target) - ] + return [{DataKeys.INPUT: (c, n), DataKeys.TARGET: t} for c, n, t in zip(cat_vars, num_vars, target)] def predict_load_data(self, data: DataFrame, dataset: Optional[Any] = None): _, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) - return [{DefaultDataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)] + return [{DataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)] -class TabularCSVDataSource(TabularDataFrameDataSource): +class TabularCSVInput(TabularDataFrameInput): def load_data(self, data: str, dataset: Optional[Any] = None): return super().load_data(pd.read_csv(data), dataset=dataset) @@ -147,7 +145,7 @@ def deserialize(self, data: str) -> Any: cat_vars = np.stack(cat_vars, 1) num_vars = np.stack(num_vars, 1) - 
return [{DefaultDataKeys.INPUT: [c, n]} for c, n in zip(cat_vars, num_vars)] + return [{DataKeys.INPUT: [c, n]} for c, n in zip(cat_vars, num_vars)] @property def example_input(self) -> str: @@ -194,15 +192,15 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.CSV: TabularCSVDataSource( + inputs={ + InputFormat.CSV: TabularCSVInput( cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression ), - "data_frame": TabularDataFrameDataSource( + "data_frame": TabularDataFrameInput( cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression ), }, - default_data_source=DefaultDataSources.CSV, + default_input=InputFormat.CSV, deserializer=deserializer or TabularDeserializer( cat_cols=cat_cols, @@ -251,19 +249,19 @@ class TabularData(DataModule): @property def codes(self) -> Dict[str, str]: - return self._data_source.codes + return self._input.codes @property def num_classes(self) -> int: - return self._data_source.num_classes + return self._input.num_classes @property def cat_cols(self) -> Optional[List[str]]: - return self._data_source.cat_cols + return self._input.cat_cols @property def num_cols(self) -> Optional[List[str]]: - return self._data_source.num_cols + return self._input.num_cols @property def num_features(self) -> int: @@ -300,9 +298,7 @@ def compute_state( ) -> Tuple[float, float, List[str], Dict[str, Any], Dict[str, Any]]: if train_data_frame is None: - raise MisconfigurationException( - "train_data_frame is required to instantiate the TabularDataFrameDataSource" - ) + raise MisconfigurationException("train_data_frame is required to instantiate the TabularDataFrameInput") data_frames = [train_data_frame] @@ -410,7 +406,7 @@ def from_data_frame( categorical_fields=categorical_fields, ) - return cls.from_data_source( + return cls.from_input( "data_frame", train_data_frame, val_data_frame, diff --git 
a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py index 4fe7591a32..7f34ad2d78 100644 --- a/flash/tabular/forecasting/data.py +++ b/flash/tabular/forecasting/data.py @@ -19,7 +19,7 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.io.input import DataKeys, Input, InputFormat from flash.core.data.io.input_transform import InputTransform from flash.core.data.process import Deserializer from flash.core.data.properties import ProcessState @@ -42,7 +42,7 @@ class TimeSeriesDataSetParametersState(ProcessState): time_series_dataset_parameters: Optional[Dict[str, Any]] -class TabularForecastingDataFrameDataSource(DataSource[DataFrame]): +class TabularForecastingDataFrameInput(Input[DataFrame]): @requires("tabular") def __init__( self, @@ -50,20 +50,20 @@ def __init__( target: Optional[Union[str, List[str]]] = None, group_ids: Optional[List[str]] = None, parameters: Optional[Dict[str, Any]] = None, - **data_source_kwargs: Any, + **input_kwargs: Any, ): super().__init__() self.time_idx = time_idx self.target = target self.group_ids = group_ids - self.data_source_kwargs = data_source_kwargs + self.input_kwargs = input_kwargs self.set_state(TimeSeriesDataSetParametersState(parameters)) def load_data(self, data: DataFrame, dataset: Optional[Any] = None): if self.training: time_series_dataset = TimeSeriesDataSet( - data, time_idx=self.time_idx, group_ids=self.group_ids, target=self.target, **self.data_source_kwargs + data, time_idx=self.time_idx, group_ids=self.group_ids, target=self.target, **self.input_kwargs ) parameters = time_series_dataset.get_parameters() @@ -91,7 +91,7 @@ def load_data(self, data: DataFrame, dataset: Optional[Any] = None): return time_series_dataset def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: - return {DefaultDataKeys.INPUT: 
sample[0], DefaultDataKeys.TARGET: sample[1]} + return {DataKeys.INPUT: sample[0], DataKeys.TARGET: sample[1]} class TabularForecastingInputTransform(InputTransform): @@ -102,23 +102,23 @@ def __init__( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, deserializer: Optional[Deserializer] = None, - **data_source_kwargs: Any, + **input_kwargs: Any, ): - self.data_source_kwargs = data_source_kwargs + self.input_kwargs = input_kwargs super().__init__( train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.DATAFRAME: TabularForecastingDataFrameDataSource(**data_source_kwargs), + inputs={ + InputFormat.DATAFRAME: TabularForecastingDataFrameInput(**input_kwargs), }, deserializer=deserializer, - default_data_source=DefaultDataSources.DATAFRAME, + default_input=InputFormat.DATAFRAME, ) def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: - return {**self.transforms, **self.data_source_kwargs} + return {**self.transforms, **self.input_kwargs} @classmethod def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> "InputTransform": @@ -212,12 +212,12 @@ def from_data_frame( ) """ - return cls.from_data_source( + return cls.from_input( time_idx=time_idx, target=target, group_ids=group_ids, parameters=parameters, - data_source=DefaultDataSources.DATAFRAME, + input=InputFormat.DATAFRAME, train_data=train_data_frame, val_data=val_data_frame, test_data=test_data_frame, diff --git a/flash/tabular/regression/model.py b/flash/tabular/regression/model.py index f0837ad14e..7e2f9d401b 100644 --- a/flash/tabular/regression/model.py +++ b/flash/tabular/regression/model.py @@ -16,7 +16,7 @@ import torch from torch.nn import functional as F -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.regression import 
RegressionTask from flash.core.utilities.imports import _TABULAR_AVAILABLE from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE @@ -90,19 +90,19 @@ def forward(self, x_in) -> torch.Tensor: return self.model(x)[0].flatten() def training_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return super().training_step(batch, batch_idx) def validation_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return super().validation_step(batch, batch_idx) def test_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = batch[DefaultDataKeys.INPUT] + batch = batch[DataKeys.INPUT] return self(batch) @classmethod diff --git a/flash/template/classification/data.py b/flash/template/classification/data.py index 989112723f..aa24f8013b 100644 --- a/flash/template/classification/data.py +++ b/flash/template/classification/data.py @@ -20,7 +20,7 @@ from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, LabelsState, NumpyDataSource +from flash.core.data.io.input import DataKeys, InputFormat, LabelsState, NumpyInput from flash.core.data.io.input_transform import InputTransform from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _SKLEARN_AVAILABLE @@ -32,11 +32,11 @@ Bunch = object -class 
TemplateNumpyDataSource(NumpyDataSource): +class TemplateNumpyInput(NumpyInput): """An example data source that records ``num_features`` on the dataset. We extend - :class:`~flash.core.data.data_source.NumpyDataSource` so that we can use ``super().load_data``. + :class:`~flash.core.data.io.input.NumpyInput` so that we can use ``super().load_data``. """ def load_data(self, data: Tuple[np.ndarray, Sequence[Any]], dataset: Any) -> Sequence[Mapping[str, Any]]: @@ -53,7 +53,7 @@ def load_data(self, data: Tuple[np.ndarray, Sequence[Any]], dataset: Any) -> Seq return super().load_data(data, dataset) -class TemplateSKLearnDataSource(TemplateNumpyDataSource): +class TemplateSKLearnInput(TemplateNumpyInput): """An example data source that loads data from an sklearn data ``Bunch``.""" def load_data(self, data: Bunch, dataset: Any) -> Sequence[Mapping[str, Any]]: @@ -104,11 +104,11 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.NUMPY: TemplateNumpyDataSource(), - "sklearn": TemplateSKLearnDataSource(), + inputs={ + InputFormat.NUMPY: TemplateNumpyInput(), + "sklearn": TemplateSKLearnInput(), }, - default_data_source=DefaultDataSources.NUMPY, + default_input=InputFormat.NUMPY, ) def get_state_dict(self) -> Dict[str, Any]: @@ -140,8 +140,8 @@ def default_transforms(self) -> Optional[Dict[str, Callable]]: """ return { "to_tensor_transform": nn.Sequential( - ApplyToKeys(DefaultDataKeys.INPUT, self.input_to_tensor), - ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ApplyToKeys(DataKeys.INPUT, self.input_to_tensor), + ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ), } @@ -154,9 +154,9 @@ class TemplateData(DataModule): """Creating our :class:`~flash.core.data.data_module.DataModule` is as easy as setting the ``input_transform_cls`` attribute. - We get the ``from_numpy`` method for free as we've configured a ``DefaultDataSources.NUMPY`` data source. 
We'll also - add a ``from_sklearn`` method so that we can use our ``TemplateSKLearnDataSource. Finally, we define the - ``num_features`` property for convenience. + We get the ``from_numpy`` method for free as we've configured an ``InputFormat.NUMPY`` data source. We'll also add a + ``from_sklearn`` method so that we can use our ``TemplateSKLearnInput``. Finally, we define the ``num_features`` + property for convenience. """ input_transform_cls = TemplateInputTransform @@ -180,7 +180,7 @@ def from_sklearn( **input_transform_kwargs: Any, ): """This is our custom ``from_*`` method. It expects scikit-learn ``Bunch`` objects as input and passes them - through to the :meth:`~flash.core.data.data_module.DataModule.from_data_source` method underneath. + through to the :meth:`~flash.core.data.data_module.DataModule.from_input` method underneath. Args: train_bunch: The scikit-learn ``Bunch`` containing the train data. @@ -209,7 +209,7 @@ def from_sklearn( Returns: The constructed data module. """ - return super().from_data_source( + return super().from_input( "sklearn", train_bunch, val_bunch, diff --git a/flash/template/classification/model.py b/flash/template/classification/model.py index 66e2ee2253..5af3e36165 100644 --- a/flash/template/classification/model.py +++ b/flash/template/classification/model.py @@ -17,7 +17,7 @@ from torch import nn from flash.core.classification import ClassificationTask, Labels -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.registry import FlashRegistry from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE from flash.template.classification.backbones import TEMPLATE_BACKBONES @@ -83,30 +83,30 @@ def __init__( self.head = nn.Linear(out_features, num_classes) def training_step(self, batch: Any, batch_idx: int) -> Any: - """For the training step, we just extract the
:attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` and - :attr:`~flash.core.data.data_source.DefaultDataKeys.TARGET` keys from the input and forward them to the + """For the training step, we just extract the :attr:`~flash.core.data.io.input.DataKeys.INPUT` and + :attr:`~flash.core.data.io.input.DataKeys.TARGET` keys from the input and forward them to the :meth:`~flash.core.model.Task.training_step`.""" - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return super().training_step(batch, batch_idx) def validation_step(self, batch: Any, batch_idx: int) -> Any: - """For the validation step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` and - :attr:`~flash.core.data.data_source.DefaultDataKeys.TARGET` keys from the input and forward them to the + """For the validation step, we just extract the :attr:`~flash.core.data.io.input.DataKeys.INPUT` and + :attr:`~flash.core.data.io.input.DataKeys.TARGET` keys from the input and forward them to the :meth:`~flash.core.model.Task.validation_step`.""" - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return super().validation_step(batch, batch_idx) def test_step(self, batch: Any, batch_idx: int) -> Any: - """For the test step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` and - :attr:`~flash.core.data.data_source.DefaultDataKeys.TARGET` keys from the input and forward them to the + """For the test step, we just extract the :attr:`~flash.core.data.io.input.DataKeys.INPUT` and + :attr:`~flash.core.data.io.input.DataKeys.TARGET` keys from the input and forward them to the :meth:`~flash.core.model.Task.test_step`.""" - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) return super().test_step(batch, batch_idx) def 
predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - """For the predict step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` key - from the input and forward it to the :meth:`~flash.core.model.Task.predict_step`.""" - batch = batch[DefaultDataKeys.INPUT] + """For the predict step, we just extract the :attr:`~flash.core.data.io.input.DataKeys.INPUT` key from the + input and forward it to the :meth:`~flash.core.model.Task.predict_step`.""" + batch = batch[DataKeys.INPUT] return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) def forward(self, x) -> torch.Tensor: diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 6290caa44f..b0e4ca5841 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -23,11 +23,11 @@ from flash.core.data.auto_dataset import AutoDataset from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources, LabelsState +from flash.core.data.io.input import DataKeys, Input, InputFormat, LabelsState from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform from flash.core.data.process import Deserializer -from flash.core.integrations.labelstudio.data_source import LabelStudioTextClassificationDataSource +from flash.core.integrations.labelstudio.input import LabelStudioTextClassificationInput from flash.core.utilities.imports import _TEXT_AVAILABLE, requires if _TEXT_AVAILABLE: @@ -61,7 +61,7 @@ def __setstate__(self, state): self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) -class TextDataSource(DataSource): +class TextInput(Input): @requires("text") def __init__(self, backbone: str, max_length: int = 128): super().__init__() @@ -86,7 +86,7 @@ def 
_transform_label(label_to_class_mapping: Dict[str, int], target: str, ex: Di @staticmethod def _multilabel_target(targets: List[str], element: Dict[str, Any]) -> Dict[str, Any]: targets = [element.pop(target) for target in targets] - element[DefaultDataKeys.TARGET] = targets + element[DataKeys.TARGET] = targets return element def _to_hf_dataset(self, data) -> Sequence[Mapping[str, Any]]: @@ -132,10 +132,10 @@ def load_data( hf_dataset = hf_dataset.map(partial(self._transform_label, label_to_class_mapping, target)) # rename label column - hf_dataset = hf_dataset.rename_column(target, DefaultDataKeys.TARGET) + hf_dataset = hf_dataset.rename_column(target, DataKeys.TARGET) # remove extra columns - extra_columns = set(hf_dataset.column_names) - {input, DefaultDataKeys.TARGET} + extra_columns = set(hf_dataset.column_names) - {input, DataKeys.TARGET} hf_dataset = hf_dataset.remove_columns(extra_columns) # tokenize @@ -159,41 +159,41 @@ def __setstate__(self, state): self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) -class TextCSVDataSource(TextDataSource): +class TextCSVInput(TextInput): def to_hf_dataset(self, data: Tuple[str, str, str]) -> Tuple[Sequence[Mapping[str, Any]], str, str]: file, *other = data dataset_dict = load_dataset("csv", data_files={"train": str(file)}) return (dataset_dict["train"], *other) -class TextJSONDataSource(TextDataSource): +class TextJSONInput(TextInput): def to_hf_dataset(self, data: Tuple[str, str, str, str]) -> Tuple[Sequence[Mapping[str, Any]], str, str]: file, *other, field = data dataset_dict = load_dataset("json", data_files={"train": str(file)}, field=field) return (dataset_dict["train"], *other) -class TextDataFrameDataSource(TextDataSource): +class TextDataFrameInput(TextInput): def to_hf_dataset(self, data: Tuple[DataFrame, str, str]) -> Tuple[Sequence[Mapping[str, Any]], str, str]: df, *other = data hf_dataset = Dataset.from_pandas(df) return (hf_dataset, *other) -class 
TextParquetDataSource(TextDataSource): +class TextParquetInput(TextInput): def to_hf_dataset(self, data: Tuple[str, str, str]) -> Tuple[Sequence[Mapping[str, Any]], str, str]: file, *other = data hf_dataset = Dataset.from_parquet(str(file)) return (hf_dataset, *other) -class TextHuggingFaceDatasetDataSource(TextDataSource): +class TextHuggingFaceDatasetInput(TextInput): def to_hf_dataset(self, data: Tuple[str, str, str]) -> Tuple[Sequence[Mapping[str, Any]], str, str]: hf_dataset, *other = data return (hf_dataset, *other) -class TextListDataSource(TextDataSource): +class TextListInput(TextInput): def to_hf_dataset( self, data: Union[Tuple[List[str], List[str]], List[str]] ) -> Tuple[Sequence[Mapping[str, Any]], Optional[List[str]]]: @@ -202,11 +202,11 @@ def to_hf_dataset( input_list, target_list = data # NOTE: here we already deal with multilabels # NOTE: here we already rename to correct column names - hf_dataset = Dataset.from_dict({DefaultDataKeys.INPUT: input_list, DefaultDataKeys.TARGET: target_list}) + hf_dataset = Dataset.from_dict({DataKeys.INPUT: input_list, DataKeys.TARGET: target_list}) return hf_dataset, target_list # predicting - hf_dataset = Dataset.from_dict({DefaultDataKeys.INPUT: data}) + hf_dataset = Dataset.from_dict({DataKeys.INPUT: data}) return (hf_dataset,) @@ -228,7 +228,7 @@ def load_data( else: dataset.multi_label = False if self.training: - labels = list(sorted(list(set(hf_dataset[DefaultDataKeys.TARGET])))) + labels = list(sorted(list(set(hf_dataset[DataKeys.TARGET])))) dataset.num_classes = len(labels) self.set_state(LabelsState(labels)) @@ -239,15 +239,13 @@ def load_data( labels = labels.labels label_to_class_mapping = {v: k for k, v in enumerate(labels)} # happens in-place and keeps the target column name - hf_dataset = hf_dataset.map( - partial(self._transform_label, label_to_class_mapping, DefaultDataKeys.TARGET) - ) + hf_dataset = hf_dataset.map(partial(self._transform_label, label_to_class_mapping, DataKeys.TARGET)) # tokenize - 
hf_dataset = hf_dataset.map(partial(self._tokenize_fn, input=DefaultDataKeys.INPUT), batched=True) + hf_dataset = hf_dataset.map(partial(self._tokenize_fn, input=DataKeys.INPUT), batched=True) # set format - hf_dataset = hf_dataset.remove_columns([DefaultDataKeys.INPUT]) # just leave the numerical columns + hf_dataset = hf_dataset.remove_columns([DataKeys.INPUT]) # just leave the numerical columns hf_dataset.set_format("torch") return hf_dataset @@ -272,20 +270,18 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.CSV: TextCSVDataSource(self.backbone, max_length=max_length), - DefaultDataSources.JSON: TextJSONDataSource(self.backbone, max_length=max_length), - DefaultDataSources.PARQUET: TextParquetDataSource(self.backbone, max_length=max_length), - DefaultDataSources.HUGGINGFACE_DATASET: TextHuggingFaceDatasetDataSource( - self.backbone, max_length=max_length - ), - DefaultDataSources.DATAFRAME: TextDataFrameDataSource(self.backbone, max_length=max_length), - DefaultDataSources.LISTS: TextListDataSource(self.backbone, max_length=max_length), - DefaultDataSources.LABELSTUDIO: LabelStudioTextClassificationDataSource( + inputs={ + InputFormat.CSV: TextCSVInput(self.backbone, max_length=max_length), + InputFormat.JSON: TextJSONInput(self.backbone, max_length=max_length), + InputFormat.PARQUET: TextParquetInput(self.backbone, max_length=max_length), + InputFormat.HUGGINGFACE_DATASET: TextHuggingFaceDatasetInput(self.backbone, max_length=max_length), + InputFormat.DATAFRAME: TextDataFrameInput(self.backbone, max_length=max_length), + InputFormat.LISTS: TextListInput(self.backbone, max_length=max_length), + InputFormat.LABELSTUDIO: LabelStudioTextClassificationInput( backbone=self.backbone, max_length=max_length ), }, - default_data_source=DefaultDataSources.LISTS, + default_input=InputFormat.LISTS, deserializer=TextDeserializer(backbone, max_length), ) @@ -385,8 +381,8 
@@ def from_data_frame( Returns: The constructed data module. """ - return cls.from_data_source( - DefaultDataSources.DATAFRAME, + return cls.from_input( + InputFormat.DATAFRAME, (train_data_frame, input_field, target_fields), (val_data_frame, input_field, target_fields), (test_data_frame, input_field, target_fields), @@ -463,8 +459,8 @@ def from_lists( Returns: The constructed data module. """ - return cls.from_data_source( - DefaultDataSources.LISTS, + return cls.from_input( + InputFormat.LISTS, (train_data, train_targets), (val_data, val_targets), (test_data, test_targets), @@ -504,8 +500,8 @@ def from_parquet( **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given PARQUET files using the - :class:`~flash.core.data.data_source.DataSource` - of name :attr:`~flash.core.data.data_source.DefaultDataSources.PARQUET` + :class:`~flash.core.data.io.input.Input` + of name :attr:`~flash.core.data.io.input.InputFormat.PARQUET` from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. Args: @@ -549,8 +545,8 @@ def from_parquet( }, ) """ - return cls.from_data_source( - DefaultDataSources.PARQUET, + return cls.from_input( + InputFormat.PARQUET, (train_file, input_field, target_fields), (val_file, input_field, target_fields), (test_file, input_field, target_fields), @@ -622,8 +618,8 @@ def from_hf_datasets( Returns: The constructed data module. 
""" - return cls.from_data_source( - DefaultDataSources.HUGGINGFACE_DATASET, + return cls.from_input( + InputFormat.HUGGINGFACE_DATASET, (train_hf_dataset, input_field, target_fields), (val_hf_dataset, input_field, target_fields), (test_hf_dataset, input_field, target_fields), diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index 950b5dc902..5650334372 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -19,7 +19,7 @@ from pytorch_lightning import Callback from flash.core.classification import ClassificationTask, Labels -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE @@ -108,7 +108,7 @@ def to_metrics_format(self, x) -> torch.Tensor: return super().to_metrics_format(x) def step(self, batch, batch_idx, metrics) -> dict: - target = batch.pop(DefaultDataKeys.TARGET) + target = batch.pop(DataKeys.TARGET) batch = (batch, target) return super().step(batch, batch_idx, metrics) diff --git a/flash/text/question_answering/data.py b/flash/text/question_answering/data.py index 00ea22a992..344b935982 100644 --- a/flash/text/question_answering/data.py +++ b/flash/text/question_answering/data.py @@ -29,7 +29,7 @@ import flash from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.io.input import DataKeys, Input, InputFormat from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform from flash.core.data.properties import ProcessState @@ -42,7 +42,7 @@ from transformers import AutoTokenizer, 
DataCollatorWithPadding, default_data_collator -class QuestionAnsweringDataSource(DataSource): +class QuestionAnsweringInput(Input): @requires("text") def __init__( self, @@ -106,14 +106,14 @@ def _tokenize_fn(self, samples: Any) -> Callable: contexts = tokenized_samples.pop("context") answers = tokenized_samples.pop("answer") - tokenized_samples[DefaultDataKeys.METADATA] = [] + tokenized_samples[DataKeys.METADATA] = [] for offset_mapping, example_id, context in zip(offset_mappings, example_ids, contexts): - tokenized_samples[DefaultDataKeys.METADATA].append( + tokenized_samples[DataKeys.METADATA].append( {"context": context, "offset_mapping": offset_mapping, "example_id": example_id} ) if self._running_stage.evaluating: for index, answer in enumerate(answers): - tokenized_samples[DefaultDataKeys.METADATA][index]["answer"] = answer + tokenized_samples[DataKeys.METADATA][index]["answer"] = answer del offset_mappings del example_ids @@ -238,7 +238,7 @@ def doc_stride(self) -> str: return self._doc_stride -class QuestionAnsweringFileDataSource(QuestionAnsweringDataSource): +class QuestionAnsweringFileInput(QuestionAnsweringInput): def __init__( self, filetype: str, @@ -344,7 +344,7 @@ def __setstate__(self, state): self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) -class QuestionAnsweringCSVDataSource(QuestionAnsweringFileDataSource): +class QuestionAnsweringCSVInput(QuestionAnsweringFileInput): def __init__( self, backbone: str, @@ -378,7 +378,7 @@ def __setstate__(self, state): self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) -class QuestionAnsweringJSONDataSource(QuestionAnsweringFileDataSource): +class QuestionAnsweringJSONInput(QuestionAnsweringFileInput): def __init__( self, backbone: str, @@ -412,7 +412,7 @@ def __setstate__(self, state): self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) -class QuestionAnsweringDictionaryDataSource(QuestionAnsweringDataSource): +class 
QuestionAnsweringDictionaryInput(QuestionAnsweringInput): def load_data(self, data: Any, columns: List[str] = None) -> "datasets.Dataset": stage = self._running_stage.value @@ -434,7 +434,7 @@ def __setstate__(self, state): self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) -class SQuADDataSource(QuestionAnsweringDataSource): +class SQuADInput(QuestionAnsweringInput): def load_data(self, data: str, dataset: Optional[Any] = None) -> "datasets.Dataset": stage = self._running_stage.value @@ -521,8 +521,8 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.CSV: QuestionAnsweringCSVDataSource( + inputs={ + InputFormat.CSV: QuestionAnsweringCSVInput( self.backbone, max_source_length=max_source_length, max_target_length=max_target_length, @@ -532,7 +532,7 @@ def __init__( answer_column_name=answer_column_name, doc_stride=doc_stride, ), - DefaultDataSources.JSON: QuestionAnsweringJSONDataSource( + InputFormat.JSON: QuestionAnsweringJSONInput( self.backbone, max_source_length=max_source_length, max_target_length=max_target_length, @@ -542,7 +542,7 @@ def __init__( answer_column_name=answer_column_name, doc_stride=doc_stride, ), - "dict": QuestionAnsweringDictionaryDataSource( + "dict": QuestionAnsweringDictionaryInput( self.backbone, max_source_length=max_source_length, max_target_length=max_target_length, @@ -552,7 +552,7 @@ def __init__( answer_column_name=answer_column_name, doc_stride=doc_stride, ), - "squad_v2": SQuADDataSource( + "squad_v2": SQuADInput( self.backbone, max_source_length=max_source_length, max_target_length=max_target_length, @@ -560,7 +560,7 @@ def __init__( doc_stride=doc_stride, ), }, - default_data_source="dict", + default_input="dict", ) self.set_state(QuestionAnsweringBackboneState(self.backbone)) @@ -691,7 +691,7 @@ def from_squad_v2( doc_stride=128, ) """ - return cls.from_data_source( + return cls.from_input( "squad_v2", 
train_file, val_file, @@ -728,8 +728,8 @@ def from_json( **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.text.question_answering.QuestionAnsweringData` object from the given JSON files - using the :class:`~flash.text.question_answering.QuestionAnsweringDataSource`of name - :attr:`~flash.core.data.data_source.DefaultDataSources.JSON` from the passed or constructed + using the :class:`~flash.text.question_answering.QuestionAnsweringInput`of name + :attr:`~flash.core.data.io.input.InputFormat.JSON` from the passed or constructed :class:`~flash.text.question_answering.QuestionAnsweringInputTransform`. Args: @@ -789,8 +789,8 @@ def from_json( doc_stride=128 ) """ - return cls.from_data_source( - DefaultDataSources.JSON, + return cls.from_input( + InputFormat.JSON, (train_file, field), (val_file, field), (test_file, field), @@ -828,8 +828,8 @@ def from_csv( **input_transform_kwargs: Any, ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given CSV files using the - :class:`~flash.core.data.data_source.DataSource` - of name :attr:`~flash.core.data.data_source.DefaultDataSources.CSV` + :class:`~flash.core.data.io.input.Input` + of name :attr:`~flash.core.data.io.input.InputFormat.CSV` from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`. 
Args: @@ -892,8 +892,8 @@ def from_csv( doc_stride=128 ) """ - return cls.from_data_source( - DefaultDataSources.CSV, + return cls.from_input( + InputFormat.CSV, train_file, val_file, test_file, diff --git a/flash/text/question_answering/model.py b/flash/text/question_answering/model.py index 1c17b0eca5..6e971dcaf7 100644 --- a/flash/text/question_answering/model.py +++ b/flash/text/question_answering/model.py @@ -27,7 +27,7 @@ from pytorch_lightning.utilities import rank_zero_info from torch import Tensor -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.finetuning import FlashBaseFinetuning from flash.core.model import Task from flash.core.registry import ExternalRegistry, FlashRegistry @@ -242,14 +242,14 @@ def _generate_answers(self, pred_start_logits, pred_end_logits, examples): return all_predictions def forward(self, batch: Any) -> Any: - metadata = batch.pop(DefaultDataKeys.METADATA) + metadata = batch.pop(DataKeys.METADATA) outputs = self.model(**batch) loss = outputs.loss start_logits = outputs.start_logits end_logits = outputs.end_logits generated_answers = self._generate_answers(start_logits, end_logits, metadata) - batch[DefaultDataKeys.METADATA] = metadata + batch[DataKeys.METADATA] = metadata return loss, generated_answers def training_step(self, batch: Any, batch_idx: int) -> Tensor: @@ -260,7 +260,7 @@ def training_step(self, batch: Any, batch_idx: int) -> Tensor: def common_step(self, prefix: str, batch: Any) -> torch.Tensor: loss, generated_answers = self(batch) - result = self.compute_metrics(generated_answers, batch[DefaultDataKeys.METADATA]) + result = self.compute_metrics(generated_answers, batch[DataKeys.METADATA]) self.log(f"{prefix}_loss", loss, on_step=False, on_epoch=True, prog_bar=True) self.log_dict(result, on_step=False, on_epoch=True, prog_bar=False) diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index 2afbea8790..c4f6d1958f 100644 
--- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -20,7 +20,7 @@ import flash from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DataSource, DefaultDataSources +from flash.core.data.io.input import Input, InputFormat from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output_transform import OutputTransform from flash.core.data.properties import ProcessState @@ -33,7 +33,7 @@ from transformers import AutoTokenizer, default_data_collator -class Seq2SeqDataSource(DataSource): +class Seq2SeqInput(Input): @requires("text") def __init__( self, @@ -94,7 +94,7 @@ def __setstate__(self, state): self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True, **self.backbone_kwargs) -class Seq2SeqFileDataSource(Seq2SeqDataSource): +class Seq2SeqFileInput(Seq2SeqInput): def __init__( self, filetype: str, @@ -162,7 +162,7 @@ def __setstate__(self, state): self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True, **self.backbone_kwargs) -class Seq2SeqCSVDataSource(Seq2SeqFileDataSource): +class Seq2SeqCSVInput(Seq2SeqFileInput): def __init__( self, backbone: str, @@ -190,7 +190,7 @@ def __setstate__(self, state): self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True, **self.backbone_kwargs) -class Seq2SeqJSONDataSource(Seq2SeqFileDataSource): +class Seq2SeqJSONInput(Seq2SeqFileInput): def __init__( self, backbone: str, @@ -218,7 +218,7 @@ def __setstate__(self, state): self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True, **self.backbone_kwargs) -class Seq2SeqSentencesDataSource(Seq2SeqDataSource): +class Seq2SeqSentencesInput(Seq2SeqInput): def load_data( self, data: Union[str, List[str]], @@ -273,22 +273,22 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.CSV: Seq2SeqCSVDataSource( + inputs={ + 
InputFormat.CSV: Seq2SeqCSVInput( self.backbone, max_source_length=max_source_length, max_target_length=max_target_length, padding=padding, **backbone_kwargs, ), - DefaultDataSources.JSON: Seq2SeqJSONDataSource( + InputFormat.JSON: Seq2SeqJSONInput( self.backbone, max_source_length=max_source_length, max_target_length=max_target_length, padding=padding, **backbone_kwargs, ), - "sentences": Seq2SeqSentencesDataSource( + "sentences": Seq2SeqSentencesInput( self.backbone, max_source_length=max_source_length, max_target_length=max_target_length, @@ -296,7 +296,7 @@ def __init__( **backbone_kwargs, ), }, - default_data_source="sentences", + default_input="sentences", deserializer=TextDeserializer(backbone, max_source_length), ) diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index 060ef5fab1..50518b069d 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -151,7 +151,7 @@ def _initialize_model_specific_parameters(self): @property def tokenizer(self) -> "PreTrainedTokenizerBase": - return self.data_pipeline.data_source.tokenizer + return self.data_pipeline.input.tokenizer def tokenize_labels(self, labels: Tensor) -> List[str]: label_str = self.tokenizer.batch_decode(labels, skip_special_tokens=True) diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index 2d72b3ad45..62ce652853 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -20,15 +20,9 @@ from torch.utils.data import Sampler from flash.core.data.data_module import DataModule -from flash.core.data.data_source import ( - DefaultDataKeys, - DefaultDataSources, - FiftyOneDataSource, - LabelsState, - PathsDataSource, -) +from flash.core.data.io.input import DataKeys, FiftyOneInput, InputFormat, LabelsState, PathsInput from flash.core.data.io.input_transform import InputTransform -from flash.core.integrations.labelstudio.data_source import LabelStudioVideoClassificationDataSource 
+from flash.core.integrations.labelstudio.input import LabelStudioVideoClassificationInput from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _KORNIA_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, lazy_import SampleCollection = None @@ -82,9 +76,9 @@ def load_sample(self, sample): return sample def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - video_path = sample[DefaultDataKeys.INPUT] + video_path = sample[DataKeys.INPUT] sample.update(self._encoded_video_to_dict(EncodedVideo.from_path(video_path))) - sample[DefaultDataKeys.METADATA] = {"filepath": video_path} + sample[DataKeys.METADATA] = {"filepath": video_path} return sample def _encoded_video_to_dict(self, video, annotation: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: @@ -122,7 +116,7 @@ def _make_encoded_video_dataset(self, data) -> "LabeledVideoDataset": raise NotImplementedError("Subclass must implement _make_encoded_video_dataset()") -class VideoClassificationPathsDataSource(BaseVideoClassification, PathsDataSource): +class VideoClassificationPathsInput(BaseVideoClassification, PathsInput): def __init__( self, clip_sampler: "ClipSampler", @@ -136,7 +130,7 @@ def __init__( decode_audio=decode_audio, decoder=decoder, ) - PathsDataSource.__init__( + PathsInput.__init__( self, extensions=("mp4", "avi"), ) @@ -152,7 +146,7 @@ def _make_encoded_video_dataset(self, data) -> "LabeledVideoDataset": return ds -class VideoClassificationListDataSource(BaseVideoClassification, PathsDataSource): +class VideoClassificationListInput(BaseVideoClassification, PathsInput): def __init__( self, clip_sampler: "ClipSampler", @@ -166,7 +160,7 @@ def __init__( decode_audio=decode_audio, decoder=decoder, ) - PathsDataSource.__init__( + PathsInput.__init__( self, extensions=("mp4", "avi"), ) @@ -222,9 +216,9 @@ def load_data(self, data: str, dataset: Optional[Any] = None) -> "LabeledVideoDa return ds -class VideoClassificationFiftyOneDataSource( +class VideoClassificationFiftyOneInput( 
BaseVideoClassification, - FiftyOneDataSource, + FiftyOneInput, ): def __init__( self, @@ -240,7 +234,7 @@ def __init__( decode_audio=decode_audio, decoder=decoder, ) - FiftyOneDataSource.__init__( + FiftyOneInput.__init__( self, label_field=label_field, ) @@ -282,7 +276,7 @@ def __init__( video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, decode_audio: bool = True, decoder: str = "pyav", - **data_source_kwargs: Any, + **_kwargs: Any, ): self.clip_sampler = clip_sampler self.clip_duration = clip_duration @@ -309,35 +303,35 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={ - DefaultDataSources.FILES: VideoClassificationListDataSource( + inputs={ + InputFormat.FILES: VideoClassificationListInput( clip_sampler, video_sampler=video_sampler, decode_audio=decode_audio, decoder=decoder, ), - DefaultDataSources.FOLDERS: VideoClassificationPathsDataSource( + InputFormat.FOLDERS: VideoClassificationPathsInput( clip_sampler, video_sampler=video_sampler, decode_audio=decode_audio, decoder=decoder, ), - DefaultDataSources.FIFTYONE: VideoClassificationFiftyOneDataSource( + InputFormat.FIFTYONE: VideoClassificationFiftyOneInput( clip_sampler, video_sampler=video_sampler, decode_audio=decode_audio, decoder=decoder, - **data_source_kwargs, + **_kwargs, ), - DefaultDataSources.LABELSTUDIO: LabelStudioVideoClassificationDataSource( + InputFormat.LABELSTUDIO: LabelStudioVideoClassificationInput( clip_sampler=clip_sampler, video_sampler=video_sampler, decode_audio=decode_audio, decoder=decoder, - **data_source_kwargs, + **_kwargs, ), }, - default_data_source=DefaultDataSources.FILES, + default_input=InputFormat.FILES, ) def get_state_dict(self) -> Dict[str, Any]: diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py index 20f2890fbd..273766e000 100644 --- a/flash/video/classification/model.py +++ b/flash/video/classification/model.py @@ -27,7 +27,7 @@ import 
flash from flash.core.classification import ClassificationTask, Labels -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.registry import FlashRegistry from flash.core.utilities.compatibility import accelerator_connector from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE @@ -168,7 +168,7 @@ def forward(self, x: Any) -> Any: def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: predictions = self(batch["video"]) - batch[DefaultDataKeys.PREDS] = predictions + batch[DataKeys.PREDS] = predictions return batch def configure_finetune_callback(self) -> List[Callback]: diff --git a/flash_examples/flash_components/custom_data_loading.py b/flash_examples/flash_components/custom_data_loading.py index fed306ee46..44f5041c27 100644 --- a/flash_examples/flash_components/custom_data_loading.py +++ b/flash_examples/flash_components/custom_data_loading.py @@ -22,8 +22,8 @@ from torch.utils.data._utils.collate import default_collate from flash import _PACKAGE_ROOT, FlashDataset -from flash.core.data.data_source import DefaultDataKeys from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform +from flash.core.data.io.input import DataKeys from flash.core.data.new_data_module import DataModule from flash.core.data.transforms import ApplyToKeys from flash.core.data.utils import download_data @@ -54,7 +54,7 @@ # If you use FlashDataset outside of Flash, the only requirements are to return a Sequence # # from load_data with FlashDataset or an Iterable with FlashIterableDataset. 
# # When using FlashDataset with Flash Tasks, the model expects the `load_sample` to return a # -# dictionary with `DefaultDataKeys` as its keys (c.f `input`, `target`, metadata) # +# dictionary with `DataKeys` as its keys (c.f `input`, `target`, metadata) # # # ############################################################################################# @@ -65,23 +65,23 @@ class MultipleFoldersImageDataset(FlashDataset): - def load_data(self, folders: List[str]) -> List[Dict[DefaultDataKeys, Any]]: + def load_data(self, folders: List[str]) -> List[Dict[DataKeys, Any]]: if self.training: self.num_classes = len(folders) return [ - {DefaultDataKeys.INPUT: os.path.join(folder, p), DefaultDataKeys.TARGET: class_idx} + {DataKeys.INPUT: os.path.join(folder, p), DataKeys.TARGET: class_idx} for class_idx, folder in enumerate(folders) for p in os.listdir(folder) ] - def load_sample(self, sample: Dict[DefaultDataKeys, Any]) -> Dict[DefaultDataKeys, Any]: - sample[DefaultDataKeys.INPUT] = image = Image.open(sample[DefaultDataKeys.INPUT]) - sample[DefaultDataKeys.METADATA] = image.size + def load_sample(self, sample: Dict[DataKeys, Any]) -> Dict[DataKeys, Any]: + sample[DataKeys.INPUT] = image = Image.open(sample[DataKeys.INPUT]) + sample[DataKeys.METADATA] = image.size return sample - def predict_load_data(self, predict_folder: str) -> List[Dict[DefaultDataKeys, Any]]: + def predict_load_data(self, predict_folder: str) -> List[Dict[DataKeys, Any]]: assert os.path.isdir(predict_folder) - return [{DefaultDataKeys.INPUT: os.path.join(predict_folder, p)} for p in os.listdir(predict_folder)] + return [{DataKeys.INPUT: os.path.join(predict_folder, p)} for p in os.listdir(predict_folder)] train_dataset = MultipleFoldersImageDataset.from_train_data(TRAIN_FOLDERS) @@ -101,7 +101,7 @@ def predict_load_data(self, predict_folder: str) -> List[Dict[DefaultDataKeys, A class BaseImageInputTransform(InputTransform): def configure_per_sample_transform(self, image_size: int = 224) -> Any: 
per_sample_transform = T.Compose([T.Resize((image_size, image_size)), T.ToTensor()]) - return ApplyToKeys(DefaultDataKeys.INPUT, per_sample_transform) + return ApplyToKeys(DataKeys.INPUT, per_sample_transform) def configure_collate(self) -> Any: return default_collate @@ -112,7 +112,7 @@ def configure_per_sample_transform(self, image_size: int = 224, rotation: float transforms = [T.Resize((image_size, image_size)), T.ToTensor()] if self.training: transforms += [T.RandomRotation(rotation)] - return ApplyToKeys(DefaultDataKeys.INPUT, T.Compose(transforms)) + return ApplyToKeys(DataKeys.INPUT, T.Compose(transforms)) # Register your transform within the Flash Dataset registry @@ -175,9 +175,9 @@ def configure_per_sample_transform(self, image_size: int = 224, rotation: float print(train_dataset[0]) # Out: # { -# : , -# : 0, -# : (500, 375) +# : , +# : 0, +# : (500, 375) # } ############################################################################################# @@ -215,16 +215,16 @@ def configure_per_sample_transform(self, image_size: int = 224, rotation: float print(datamodule.train_dataset[0]) # Out: # { -# : , -# : 0, -# : (500, 375) +# : , +# : 0, +# : (500, 375) # } assert isinstance(datamodule.predict_dataset, FlashDataset) print(datamodule.predict_dataset[0]) # out: # { -# {: 'data/hymenoptera_data/train/ants/957233405_25c1d1187b.jpg'} +# {: 'data/hymenoptera_data/train/ants/957233405_25c1d1187b.jpg'} # } @@ -232,9 +232,9 @@ def configure_per_sample_transform(self, image_size: int = 224, rotation: float batch = next(iter(datamodule.train_dataloader())) # Out: # { -# : tensor([...]), -# : tensor([...]), -# : [(...), (...), ...], +# : tensor([...]), +# : tensor([...]), +# : [(...), (...), ...], # } print(batch) @@ -297,8 +297,8 @@ def from_multiple_folders( batch = next(iter(datamodule.train_dataloader())) # Out: # { -# : tensor([...]), -# : tensor([...]), -# : [(...), (...), ...], +# : tensor([...]), +# : tensor([...]), +# : [(...), (...), ...], # } 
print(batch) diff --git a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py index af6b51feb8..00890af201 100644 --- a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py +++ b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py @@ -24,7 +24,7 @@ from torch import nn import flash -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.transforms import ApplyToKeys, kornia_collate from flash.image import ImageClassificationData, ImageClassifier @@ -37,11 +37,11 @@ train_transform = { "to_tensor_transform": nn.Sequential( - ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), - ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ), "post_tensor_transform": ApplyToKeys( - DefaultDataKeys.INPUT, + DataKeys.INPUT, Kg.Resize((196, 196)), # SPATIAL Ka.RandomHorizontalFlip(p=0.25), @@ -58,7 +58,7 @@ ), "collate": kornia_collate, "per_batch_transform_on_device": ApplyToKeys( - DefaultDataKeys.INPUT, + DataKeys.INPUT, Ka.RandomHorizontalFlip(p=0.25), ), } diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py index 44010e7af1..d43712cb28 100644 --- a/tests/audio/classification/test_data.py +++ b/tests/audio/classification/test_data.py @@ -20,7 +20,7 @@ import torch.nn as nn from flash.audio import AudioClassificationData -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _TORCHVISION_AVAILABLE from tests.helpers.utils import _AUDIO_TESTING @@ -275,8 +275,8 @@ def 
test_from_filepaths_splits(tmpdir): _to_tensor = { "to_tensor_transform": nn.Sequential( - ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), - ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ), } diff --git a/tests/audio/speech_recognition/test_data.py b/tests/audio/speech_recognition/test_data.py index 30a069e0d9..86ee97a7a9 100644 --- a/tests/audio/speech_recognition/test_data.py +++ b/tests/audio/speech_recognition/test_data.py @@ -19,7 +19,7 @@ import flash from flash.audio import SpeechRecognitionData -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _AUDIO_AVAILABLE from tests.helpers.utils import _AUDIO_TESTING @@ -54,8 +54,8 @@ def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) dm = SpeechRecognitionData.from_csv("file", "text", train_file=csv_path, batch_size=1, num_workers=0) batch = next(iter(dm.train_dataloader())) - assert DefaultDataKeys.INPUT in batch - assert DefaultDataKeys.TARGET in batch + assert DataKeys.INPUT in batch + assert DataKeys.TARGET in batch @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @@ -66,12 +66,12 @@ def test_stage_test_and_valid(tmpdir): "file", "text", train_file=csv_path, val_file=csv_path, test_file=csv_path, batch_size=1, num_workers=0 ) batch = next(iter(dm.val_dataloader())) - assert DefaultDataKeys.INPUT in batch - assert DefaultDataKeys.TARGET in batch + assert DataKeys.INPUT in batch + assert DataKeys.TARGET in batch batch = next(iter(dm.test_dataloader())) - assert DefaultDataKeys.INPUT in batch - assert DefaultDataKeys.TARGET in batch + assert DataKeys.INPUT in batch + assert DataKeys.TARGET in batch @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @@ -80,8 +80,8 @@ def test_from_json(tmpdir): json_path = 
json_data(tmpdir) dm = SpeechRecognitionData.from_json("file", "text", train_file=json_path, batch_size=1, num_workers=0) batch = next(iter(dm.train_dataloader())) - assert DefaultDataKeys.INPUT in batch - assert DefaultDataKeys.TARGET in batch + assert DataKeys.INPUT in batch + assert DataKeys.TARGET in batch @pytest.mark.skipif(_AUDIO_AVAILABLE, reason="audio libraries are installed.") diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py index 6fe31b83c9..19299ecc60 100644 --- a/tests/audio/speech_recognition/test_model.py +++ b/tests/audio/speech_recognition/test_model.py @@ -23,7 +23,7 @@ from flash.__main__ import main from flash.audio import SpeechRecognition from flash.audio.speech_recognition.data import SpeechRecognitionInputTransform, SpeechRecognitionOutputTransform -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _AUDIO_AVAILABLE from tests.helpers.utils import _AUDIO_TESTING, _SERVE_TESTING @@ -33,9 +33,9 @@ class DummyDataset(torch.utils.data.Dataset): def __getitem__(self, index): return { - DefaultDataKeys.INPUT: np.random.randn(86631), - DefaultDataKeys.TARGET: "some target text", - DefaultDataKeys.METADATA: {"sampling_rate": 16000}, + DataKeys.INPUT: np.random.randn(86631), + DataKeys.TARGET: "some target text", + DataKeys.METADATA: {"sampling_rate": 16000}, } def __len__(self) -> int: diff --git a/tests/core/data/io/test_input.py b/tests/core/data/io/test_input.py new file mode 100644 index 0000000000..0b50c7a920 --- /dev/null +++ b/tests/core/data/io/test_input.py @@ -0,0 +1,23 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.core.data.io.input import DataKeys, DatasetInput + + +def test_dataset_input(): + dataset_input = DatasetInput() + + input, target = "test", 3 + + assert dataset_input.load_sample((input, target)) == {DataKeys.INPUT: input, DataKeys.TARGET: target} + assert dataset_input.load_sample(input) == {DataKeys.INPUT: input} diff --git a/tests/core/data/io/test_input_transform.py b/tests/core/data/io/test_input_transform.py index 92513bfaac..08ccc76cf5 100644 --- a/tests/core/data/io/test_input_transform.py +++ b/tests/core/data/io/test_input_transform.py @@ -19,7 +19,7 @@ from torch.utils.data._utils.collate import default_collate from flash import DataModule -from flash.core.data.data_source import DefaultDataSources +from flash.core.data.io.input import InputFormat from flash.core.data.io.input_transform import ( _InputTransformProcessor, _InputTransformSequential, @@ -31,11 +31,11 @@ class CustomInputTransform(DefaultInputTransform): def __init__(self): super().__init__( - data_sources={ + inputs={ "test": Mock(return_value="test"), - DefaultDataSources.TENSORS: Mock(return_value="tensors"), + InputFormat.TENSORS: Mock(return_value="tensors"), }, - default_data_source="test", + default_input="test", ) @@ -79,30 +79,30 @@ def test_sequential_str(): ) -def test_data_source_of_name(): +def test_input_of_name(): input_transform = CustomInputTransform() - assert input_transform.data_source_of_name("test")() == "test" - assert input_transform.data_source_of_name(DefaultDataSources.TENSORS)() == "tensors" - assert 
input_transform.data_source_of_name("tensors")() == "tensors" - assert input_transform.data_source_of_name("default")() == "test" + assert input_transform.input_of_name("test")() == "test" + assert input_transform.input_of_name(InputFormat.TENSORS)() == "tensors" + assert input_transform.input_of_name("tensors")() == "tensors" + assert input_transform.input_of_name("default")() == "test" with pytest.raises(MisconfigurationException, match="available data sources are: test, tensor"): - input_transform.data_source_of_name("not available") + input_transform.input_of_name("not available") -def test_available_data_sources(): +def test_available_inputs(): input_transform = CustomInputTransform() - assert DefaultDataSources.TENSORS in input_transform.available_data_sources() - assert "test" in input_transform.available_data_sources() - assert len(input_transform.available_data_sources()) == 3 + assert InputFormat.TENSORS in input_transform.available_inputs() + assert "test" in input_transform.available_inputs() + assert len(input_transform.available_inputs()) == 3 data_module = DataModule(input_transform=input_transform) - assert DefaultDataSources.TENSORS in data_module.available_data_sources() - assert "test" in data_module.available_data_sources() - assert len(data_module.available_data_sources()) == 3 + assert InputFormat.TENSORS in data_module.available_inputs() + assert "test" in data_module.available_inputs() + assert len(data_module.available_inputs()) == 3 def test_check_transforms(): diff --git a/tests/core/data/io/test_output.py b/tests/core/data/io/test_output.py index a0890bfb04..c7626a1b6e 100644 --- a/tests/core/data/io/test_output.py +++ b/tests/core/data/io/test_output.py @@ -19,7 +19,7 @@ from flash.core.classification import Labels from flash.core.data.data_pipeline import DataPipeline, DataPipelineState -from flash.core.data.data_source import LabelsState +from flash.core.data.io.input import LabelsState from flash.core.data.io.input_transform import 
DefaultInputTransform from flash.core.data.io.output import Output from flash.core.model import Task diff --git a/tests/core/data/test_auto_dataset.py b/tests/core/data/test_auto_dataset.py index 7c65a160b5..8e0db9c14c 100644 --- a/tests/core/data/test_auto_dataset.py +++ b/tests/core/data/test_auto_dataset.py @@ -17,11 +17,11 @@ from flash.core.data.auto_dataset import AutoDataset, BaseAutoDataset, IterableAutoDataset from flash.core.data.callback import FlashCallback -from flash.core.data.data_source import DataSource +from flash.core.data.io.input import Input from flash.core.utilities.stages import RunningStage -class _AutoDatasetTestDataSource(DataSource): +class _AutoDatasetTestInput(Input): def __init__(self, with_dset: bool): self._callbacks: List[FlashCallback] = [] self.load_data_count = 0 @@ -87,14 +87,14 @@ def train_load_data_with_dataset(self, data, dataset): @pytest.mark.parametrize("running_stage", [RunningStage.TRAINING, RunningStage.TESTING, RunningStage.VALIDATING]) def test_base_autodataset_smoke(running_stage): dt = range(10) - ds = DataSource() - dset = BaseAutoDataset(data=dt, data_source=ds, running_stage=running_stage) + ds = Input() + dset = BaseAutoDataset(data=dt, input=ds, running_stage=running_stage) assert dset is not None assert dset.running_stage == running_stage # check on members assert dset.data == dt - assert dset.data_source == ds + assert dset.input == ds # test set the running stage dset.running_stage = RunningStage.PREDICTING @@ -108,15 +108,15 @@ def test_base_autodataset_smoke(running_stage): def test_autodataset_smoke(): num_samples = 20 dt = range(num_samples) - ds = DataSource() + ds = Input() - dset = AutoDataset(data=dt, data_source=ds, running_stage=RunningStage.TRAINING) + dset = AutoDataset(data=dt, input=ds, running_stage=RunningStage.TRAINING) assert dset is not None assert dset.running_stage == RunningStage.TRAINING # check on members assert dset.data == dt - assert dset.data_source == ds + assert dset.input == 
ds # test set the running stage dset.running_stage = RunningStage.PREDICTING @@ -136,15 +136,15 @@ def test_autodataset_smoke(): def test_iterable_autodataset_smoke(): num_samples = 20 dt = range(num_samples) - ds = DataSource() + ds = Input() - dset = IterableAutoDataset(data=dt, data_source=ds, running_stage=RunningStage.TRAINING) + dset = IterableAutoDataset(data=dt, input=ds, running_stage=RunningStage.TRAINING) assert dset is not None assert dset.running_stage == RunningStage.TRAINING # check on members assert dset.data == dt - assert dset.data_source == ds + assert dset.input == ds # test set the running stage dset.running_stage = RunningStage.PREDICTING @@ -168,11 +168,11 @@ def test_iterable_autodataset_smoke(): False, ], ) -def test_input_transforming_data_source_with_running_stage(with_dataset): - data_source = _AutoDatasetTestDataSource(with_dataset) +def test_input_transforming_input_with_running_stage(with_dataset): + input = _AutoDatasetTestInput(with_dataset) running_stage = RunningStage.TRAINING - dataset = data_source.generate_dataset(range(10), running_stage=running_stage) + dataset = input.generate_dataset(range(10), running_stage=running_stage) assert len(dataset) == 10 @@ -182,8 +182,8 @@ def test_input_transforming_data_source_with_running_stage(with_dataset): if with_dataset: assert dataset.train_load_sample_was_called assert dataset.train_load_data_was_called - assert data_source.train_load_sample_with_dataset_count == len(dataset) - assert data_source.train_load_data_with_dataset_count == 1 + assert input.train_load_sample_with_dataset_count == len(dataset) + assert input.train_load_data_with_dataset_count == 1 else: - assert data_source.train_load_sample_count == len(dataset) - assert data_source.train_load_data_count == 1 + assert input.train_load_sample_count == len(dataset) + assert input.train_load_data_count == 1 diff --git a/tests/core/data/test_base_viz.py b/tests/core/data/test_base_viz.py index d4b2ce34be..0a865a14de 100644 --- 
a/tests/core/data/test_base_viz.py +++ b/tests/core/data/test_base_viz.py @@ -21,7 +21,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from flash.core.data.base_viz import BaseVisualization -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.utils import _CALLBACK_FUNCS, _STAGES_PREFIX from flash.core.utilities.imports import _PIL_AVAILABLE from flash.core.utilities.stages import RunningStage @@ -118,7 +118,7 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: is_predict = stage == "predict" def _extract_data(data): - return data[0][DefaultDataKeys.INPUT] + return data[0][DataKeys.INPUT] def _get_result(function_name: str): return dm.data_fetcher.batches[stage][function_name] @@ -129,7 +129,7 @@ def _get_result(function_name: str): if not is_predict: res = _get_result("load_sample") - assert isinstance(res[0][DefaultDataKeys.TARGET], int) + assert isinstance(res[0][DataKeys.TARGET], int) res = _get_result("to_tensor_transform") assert len(res) == B @@ -137,21 +137,21 @@ def _get_result(function_name: str): if not is_predict: res = _get_result("to_tensor_transform") - assert isinstance(res[0][DefaultDataKeys.TARGET], torch.Tensor) + assert isinstance(res[0][DataKeys.TARGET], torch.Tensor) res = _get_result("collate") assert _extract_data(res).shape == (B, 3, 196, 196) if not is_predict: res = _get_result("collate") - assert res[0][DefaultDataKeys.TARGET].shape == torch.Size([2]) + assert res[0][DataKeys.TARGET].shape == torch.Size([2]) res = _get_result("per_batch_transform") assert _extract_data(res).shape == (B, 3, 196, 196) if not is_predict: res = _get_result("per_batch_transform") - assert res[0][DefaultDataKeys.TARGET].shape == (B,) + assert res[0][DataKeys.TARGET].shape == (B,) assert dm.data_fetcher.show_load_sample_called assert dm.data_fetcher.show_pre_tensor_transform_called diff --git a/tests/core/data/test_callback.py 
b/tests/core/data/test_callback.py index 577c84ecd5..b9245d2760 100644 --- a/tests/core/data/test_callback.py +++ b/tests/core/data/test_callback.py @@ -31,7 +31,7 @@ def test_flash_callback(_, __, tmpdir): callback_mock = MagicMock() inputs = [[torch.rand(1), torch.rand(1)]] - dm = DataModule.from_data_source( + dm = DataModule.from_input( "default", inputs, inputs, inputs, None, input_transform=DefaultInputTransform(), batch_size=1, num_workers=0 ) dm.input_transform.callbacks += [callback_mock] @@ -58,7 +58,7 @@ def __init__(self): limit_train_batches=1, progress_bar_refresh_rate=0, ) - dm = DataModule.from_data_source( + dm = DataModule.from_input( "default", inputs, inputs, inputs, None, input_transform=DefaultInputTransform(), batch_size=1, num_workers=0 ) dm.input_transform.callbacks += [callback_mock] diff --git a/tests/core/data/test_callbacks.py b/tests/core/data/test_callbacks.py index d61a591c94..bac9e660f7 100644 --- a/tests/core/data/test_callbacks.py +++ b/tests/core/data/test_callbacks.py @@ -45,7 +45,7 @@ def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_dat input_transform = DefaultInputTransform() - return cls.from_data_source( + return cls.from_input( "default", train_data=train_data, val_data=val_data, diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py index 268850729c..e0c6a0d6b2 100644 --- a/tests/core/data/test_data_pipeline.py +++ b/tests/core/data/test_data_pipeline.py @@ -27,7 +27,7 @@ from flash.core.data.auto_dataset import IterableAutoDataset from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import _StageOrchestrator, DataPipeline, DataPipelineState -from flash.core.data.data_source import DataSource +from flash.core.data.io.input import Input from flash.core.data.io.input_transform import _InputTransformProcessor, DefaultInputTransform, InputTransform from flash.core.data.io.output import Output from 
flash.core.data.io.output_transform import _OutputTransformProcessor, OutputTransform @@ -72,14 +72,14 @@ def test_get_state(): def test_data_pipeline_str(): data_pipeline = DataPipeline( - data_source=cast(DataSource, "data_source"), + input=cast(Input, "input"), input_transform=cast(InputTransform, "input_transform"), output_transform=cast(OutputTransform, "output_transform"), output=cast(Output, "output"), deserializer=cast(Deserializer, "deserializer"), ) - expected = "data_source=data_source, deserializer=deserializer, " + expected = "input=input, deserializer=deserializer, " expected += "input_transform=input_transform, output_transform=output_transform, output=output" assert str(data_pipeline) == (f"DataPipeline({expected})") @@ -533,7 +533,7 @@ def __len__(self) -> int: return 5 -class TestInputTransformationsDataSource(DataSource): +class TestInputTransformationsInput(Input): def __init__(self): super().__init__() @@ -593,7 +593,7 @@ def predict_load_data(self, sample) -> LamdaDummyDataset: class TestInputTransformations(DefaultInputTransform): def __init__(self): - super().__init__(data_sources={"default": TestInputTransformationsDataSource()}) + super().__init__(inputs={"default": TestInputTransformationsInput()}) self.train_pre_tensor_transform_called = False self.train_collate_called = False @@ -691,7 +691,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): def test_datapipeline_transformations(tmpdir): - datamodule = DataModule.from_data_source( + datamodule = DataModule.from_input( "default", 1, 1, 1, 1, batch_size=2, num_workers=0, input_transform=TestInputTransformations() ) @@ -704,7 +704,7 @@ def test_datapipeline_transformations(tmpdir): with pytest.raises(MisconfigurationException, match="When ``to_tensor_transform``"): batch = next(iter(datamodule.val_dataloader())) - datamodule = DataModule.from_data_source( + datamodule = DataModule.from_input( "default", 1, 1, 1, 1, batch_size=2, num_workers=0, 
input_transform=TestInputTransformations2() ) batch = next(iter(datamodule.val_dataloader())) @@ -725,26 +725,26 @@ def test_datapipeline_transformations(tmpdir): trainer.predict(model) input_transform = model._input_transform - data_source = input_transform.data_source_of_name("default") - assert data_source.train_load_data_called + input = input_transform.input_of_name("default") + assert input.train_load_data_called assert input_transform.train_pre_tensor_transform_called assert input_transform.train_collate_called assert input_transform.train_per_batch_transform_on_device_called - assert data_source.val_load_data_called - assert data_source.val_load_sample_called + assert input.val_load_data_called + assert input.val_load_sample_called assert input_transform.val_to_tensor_transform_called assert input_transform.val_collate_called assert input_transform.val_per_batch_transform_on_device_called - assert data_source.test_load_data_called + assert input.test_load_data_called assert input_transform.test_to_tensor_transform_called assert input_transform.test_post_tensor_transform_called - assert data_source.predict_load_data_called + assert input.predict_load_data_called @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_datapipeline_transformations_overridden_by_task(): # define input transforms - class ImageDataSource(DataSource): + class ImageInput(Input): def load_data(self, folder: str): # from folder -> return files paths return ["a.jpg", "b.jpg"] @@ -766,7 +766,7 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={"default": ImageDataSource()}, + inputs={"default": ImageInput()}, ) def default_transforms(self): @@ -800,7 +800,7 @@ class CustomDataModule(DataModule): input_transform_cls = ImageClassificationInputTransform - datamodule = CustomDataModule.from_data_source( + datamodule = CustomDataModule.from_input( "default", "train_folder", 
"val_folder", @@ -841,7 +841,7 @@ def val_collate(self, *_): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @patch("torch.save") # need to mock torch.save or we get pickle error def test_dummy_example(tmpdir): - class ImageDataSource(DataSource): + class ImageInput(Input): def load_data(self, folder: str): # from folder -> return files paths return ["a.jpg", "b.jpg"] @@ -866,7 +866,7 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={"default": ImageDataSource()}, + inputs={"default": ImageInput()}, ) self._to_tensor = to_tensor_transform self._train_per_sample_transform_on_device = train_per_sample_transform_on_device @@ -896,7 +896,7 @@ class CustomDataModule(DataModule): input_transform_cls = ImageClassificationInputTransform - datamodule = CustomDataModule.from_data_source( + datamodule = CustomDataModule.from_input( "default", "train_folder", "val_folder", @@ -1008,11 +1008,11 @@ def per_batch_transform(self, batch: Any) -> Any: def test_iterable_auto_dataset(tmpdir): - class CustomDataSource(DataSource): + class CustomInput(Input): def load_sample(self, index: int) -> Dict[str, int]: return {"index": index} - ds = IterableAutoDataset(range(10), data_source=CustomDataSource(), running_stage=RunningStage.TRAINING) + ds = IterableAutoDataset(range(10), input=CustomInput(), running_stage=RunningStage.TRAINING) for index, v in enumerate(ds): assert v == {"index": index} diff --git a/tests/core/data/test_data_source.py b/tests/core/data/test_data_source.py index 24a0b875fc..d030a13712 100644 --- a/tests/core/data/test_data_source.py +++ b/tests/core/data/test_data_source.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from flash.core.data.data_source import DatasetDataSource, DefaultDataKeys +import pytest +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys -def test_dataset_data_source(): - data_source = DatasetDataSource() - input, target = "test", 3 +def test_default_data_keys_deprecation(): + with pytest.warns(FutureWarning, match="`DefaultDataKeys` was deprecated in 0.6.0"): + _ = DefaultDataKeys.INPUT - assert data_source.load_sample((input, target)) == {DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} - assert data_source.load_sample(input) == {DefaultDataKeys.INPUT: input} + assert DefaultDataKeys.INPUT == DataKeys.INPUT diff --git a/tests/core/data/test_transforms.py b/tests/core/data/test_transforms.py index b66bd41cc8..7b29e5a922 100644 --- a/tests/core/data/test_transforms.py +++ b/tests/core/data/test_transforms.py @@ -17,7 +17,7 @@ import torch from torch import nn -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.transforms import ApplyToKeys, kornia_collate, KorniaParallelTransforms, merge_transforms from flash.core.data.utils import convert_to_modules @@ -26,10 +26,10 @@ class TestApplyToKeys: @pytest.mark.parametrize( "sample, keys, expected", [ - ({DefaultDataKeys.INPUT: "test"}, DefaultDataKeys.INPUT, "test"), + ({DataKeys.INPUT: "test"}, DataKeys.INPUT, "test"), ( - {DefaultDataKeys.INPUT: "test_a", DefaultDataKeys.TARGET: "test_b"}, - [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], + {DataKeys.INPUT: "test_a", DataKeys.TARGET: "test_b"}, + [DataKeys.INPUT, DataKeys.TARGET], ["test_a", "test_b"], ), ({"input": "test"}, "input", "test"), @@ -51,13 +51,12 @@ def test_forward(self, sample, keys, expected): "transform, expected", [ ( - ApplyToKeys(DefaultDataKeys.INPUT, torch.nn.ReLU()), - "ApplyToKeys(keys=<DefaultDataKeys.INPUT: 'input'>, transform=ReLU())", + ApplyToKeys(DataKeys.INPUT, torch.nn.ReLU()), + "ApplyToKeys(keys=<DataKeys.INPUT: 'input'>, transform=ReLU())", ), ( 
- ApplyToKeys([DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], torch.nn.ReLU()), - "ApplyToKeys(keys=[<DefaultDataKeys.INPUT: 'input'>, " - "<DefaultDataKeys.TARGET: 'target'>], transform=ReLU())", + ApplyToKeys([DataKeys.INPUT, DataKeys.TARGET], torch.nn.ReLU()), + "ApplyToKeys(keys=[<DataKeys.INPUT: 'input'>, " "<DataKeys.TARGET: 'target'>], transform=ReLU())", ), (ApplyToKeys("input", torch.nn.ReLU()), "ApplyToKeys(keys='input', transform=ReLU())"), ( @@ -100,15 +99,15 @@ def test_kornia_parallel_transforms(with_params): def test_kornia_collate(): samples = [ - {DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), DefaultDataKeys.TARGET: 1}, - {DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), DefaultDataKeys.TARGET: 2}, - {DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), DefaultDataKeys.TARGET: 3}, + {DataKeys.INPUT: torch.zeros(1, 3, 10, 10), DataKeys.TARGET: 1}, + {DataKeys.INPUT: torch.zeros(1, 3, 10, 10), DataKeys.TARGET: 2}, + {DataKeys.INPUT: torch.zeros(1, 3, 10, 10), DataKeys.TARGET: 3}, ] result = kornia_collate(samples) - assert torch.all(result[DefaultDataKeys.TARGET] == torch.tensor([1, 2, 3])) - assert list(result[DefaultDataKeys.INPUT].shape) == [3, 3, 10, 10] - assert torch.allclose(result[DefaultDataKeys.INPUT], torch.zeros(1)) + assert torch.all(result[DataKeys.TARGET] == torch.tensor([1, 2, 3])) + assert list(result[DataKeys.INPUT].shape) == [3, 3, 10, 10] + assert torch.allclose(result[DataKeys.INPUT], torch.zeros(1)) _MOCK_TRANSFORM = Mock() diff --git a/tests/core/integrations/labelstudio/test_labelstudio.py b/tests/core/integrations/labelstudio/test_labelstudio.py index 9e04d839ec..8dd6ff82d8 100644 --- a/tests/core/integrations/labelstudio/test_labelstudio.py +++ b/tests/core/integrations/labelstudio/test_labelstudio.py @@ -1,11 +1,11 @@ import pytest -from flash.core.data.data_source import DefaultDataSources +from flash.core.data.io.input import InputFormat from flash.core.data.utils import download_data -from flash.core.integrations.labelstudio.data_source import ( - LabelStudioDataSource, - LabelStudioImageClassificationDataSource, - 
LabelStudioTextClassificationDataSource, +from flash.core.integrations.labelstudio.input import ( + LabelStudioImageClassificationInput, + LabelStudioInput, + LabelStudioTextClassificationInput, ) from flash.core.integrations.labelstudio.visualizer import launch_app from flash.image.classification.data import ImageClassificationData @@ -125,22 +125,22 @@ def test_utility_load(): "project": 7, } ] - ds = LabelStudioDataSource._load_json_data(data=data, data_folder=".", multi_label=False) + ds = LabelStudioInput._load_json_data(data=data, data_folder=".", multi_label=False) assert ds[3] == {"image"} assert ds[2] == {"Road", "Car", "Obstacle"} assert len(ds[1]) == 0 assert len(ds[0]) == 5 - ds_multi = LabelStudioDataSource._load_json_data(data=data, data_folder=".", multi_label=True) + ds_multi = LabelStudioInput._load_json_data(data=data, data_folder=".", multi_label=True) assert ds_multi[3] == {"image"} assert ds_multi[2] == {"Road", "Car", "Obstacle"} assert len(ds_multi[1]) == 0 assert len(ds_multi[0]) == 5 -def test_datasource_labelstudio(): - """Test creation of LabelStudioDataSource.""" +def test_Input_labelstudio(): + """Test creation of LabelStudioInput.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/data.zip") - ds = LabelStudioDataSource() + ds = LabelStudioInput() data = { "data_folder": "data/upload/", "export_json": "data/project.json", @@ -154,7 +154,7 @@ def test_datasource_labelstudio(): assert val_sample assert test assert not predict - ds_no_split = LabelStudioDataSource() + ds_no_split = LabelStudioInput() data = { "data_folder": "data/upload/", "export_json": "data/project.json", @@ -166,8 +166,8 @@ def test_datasource_labelstudio(): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -def test_datasource_labelstudio_image(): - """Test creation of LabelStudioImageClassificationDataSource from images.""" +def test_Input_labelstudio_image(): + """Test creation of 
LabelStudioImageClassificationInput from images.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/data_nofile.zip") data = { @@ -176,7 +176,7 @@ def test_datasource_labelstudio_image(): "split": 0.2, "multi_label": True, } - ds = LabelStudioImageClassificationDataSource() + ds = LabelStudioImageClassificationInput() train, val, test, predict = ds.to_datasets(train_data=data, val_data=data, test_data=data, predict_data=data) train_sample = train[0] val_sample = val[0] @@ -190,7 +190,7 @@ def test_datasource_labelstudio_image(): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_datamodule_labelstudio_image(): - """Test creation of LabelStudioImageClassificationDataSource and Datamodule from images.""" + """Test creation of LabelStudioImageClassificationInput and Datamodule from images.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/data.zip") datamodule = ImageClassificationData.from_labelstudio( @@ -205,7 +205,7 @@ def test_datamodule_labelstudio_image(): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_label_studio_predictions_visualization(): - """Test creation of LabelStudioImageClassificationDataSource and Datamodule from images.""" + """Test creation of LabelStudioImageClassificationInput and Datamodule from images.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/data.zip") datamodule = ImageClassificationData.from_labelstudio( @@ -229,8 +229,8 @@ def test_label_studio_predictions_visualization(): @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") -def test_datasource_labelstudio_text(): - """Test creation of LabelStudioTextClassificationDataSource and Datamodule from text.""" +def test_Input_labelstudio_text(): + """Test creation of LabelStudioTextClassificationInput and Datamodule from text.""" 
download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/text_data.zip", "./data/") backbone = "prajjwal1/bert-medium" data = { @@ -239,7 +239,7 @@ def test_datasource_labelstudio_text(): "split": 0.2, "multi_label": False, } - ds = LabelStudioTextClassificationDataSource(backbone=backbone) + ds = LabelStudioTextClassificationInput(backbone=backbone) train, val, test, predict = ds.to_datasets(train_data=data, test_data=data) train_sample = train[0] test_sample = test[0] @@ -252,7 +252,7 @@ def test_datasource_labelstudio_text(): @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_datamodule_labelstudio_text(): - """Test creation of LabelStudioTextClassificationDataSource and Datamodule from text.""" + """Test creation of LabelStudioTextClassificationInput and Datamodule from text.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/text_data.zip", "./data/") backbone = "prajjwal1/bert-medium" datamodule = TextClassificationData.from_labelstudio( @@ -268,12 +268,12 @@ def test_datamodule_labelstudio_text(): @pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") -def test_datasource_labelstudio_video(): - """Test creation of LabelStudioVideoClassificationDataSource from video.""" +def test_Input_labelstudio_video(): + """Test creation of LabelStudioVideoClassificationInput from video.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/video_data.zip") data = {"data_folder": "data/upload/", "export_json": "data/project.json", "multi_label": True} input_transform = VideoClassificationInputTransform() - ds = input_transform.data_source_of_name(DefaultDataSources.LABELSTUDIO) + ds = input_transform.input_of_name(InputFormat.LABELSTUDIO) train, val, test, predict = ds.to_datasets(train_data=data, test_data=data) sample_iter = iter(train) sample = next(sample_iter) diff --git 
a/tests/core/integrations/vissl/test_transforms.py b/tests/core/integrations/vissl/test_transforms.py index fa379acda3..0d5ff2a900 100644 --- a/tests/core/integrations/vissl/test_transforms.py +++ b/tests/core/integrations/vissl/test_transforms.py @@ -13,7 +13,7 @@ # limitations under the License. import pytest -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE from tests.image.embedding.utils import ssl_datamodule @@ -35,7 +35,7 @@ def test_multicrop_input_transform(): )._train_dataloader() batch = next(iter(train_dataloader)) - assert len(batch[DefaultDataKeys.INPUT]) == total_num_crops - assert batch[DefaultDataKeys.INPUT][0].shape == (batch_size, 3, size_crops[0], size_crops[0]) - assert batch[DefaultDataKeys.INPUT][-1].shape == (batch_size, 3, size_crops[-1], size_crops[-1]) - assert list(batch[DefaultDataKeys.TARGET].shape) == [batch_size] + assert len(batch[DataKeys.INPUT]) == total_num_crops + assert batch[DataKeys.INPUT][0].shape == (batch_size, 3, size_crops[0], size_crops[0]) + assert batch[DataKeys.INPUT][-1].shape == (batch_size, 3, size_crops[-1], size_crops[-1]) + assert list(batch[DataKeys.TARGET].shape) == [batch_size] diff --git a/tests/core/test_classification.py b/tests/core/test_classification.py index a7c6f6f38a..d322c1bba3 100644 --- a/tests/core/test_classification.py +++ b/tests/core/test_classification.py @@ -15,7 +15,7 @@ import torch from flash.core.classification import Classes, FiftyOneLabels, Labels, Logits, Probabilities -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE @@ -47,7 +47,7 @@ def test_classification_outputs_multi_label(): def test_classification_outputs_fiftyone(): logits = torch.tensor([-0.1, 0.2, 0.3]) - example_output = {DefaultDataKeys.PREDS: logits, 
DefaultDataKeys.METADATA: {"filepath": "something"}} # 3 classes + example_output = {DataKeys.PREDS: logits, DataKeys.METADATA: {"filepath": "something"}} # 3 classes labels = ["class_1", "class_2", "class_3"] predictions = FiftyOneLabels(return_filepath=True).transform(example_output) diff --git a/tests/graph/classification/test_model.py b/tests/graph/classification/test_model.py index 271e7ecab5..3a5d7f2c6f 100644 --- a/tests/graph/classification/test_model.py +++ b/tests/graph/classification/test_model.py @@ -74,7 +74,7 @@ def test_predict_dataset(tmpdir): tudataset = datasets.TUDataset(root=tmpdir, name="KKI") model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) data_pipe = DataPipeline(input_transform=GraphClassificationInputTransform()) - out = model.predict(tudataset, data_source="datasets", data_pipeline=data_pipe) + out = model.predict(tudataset, input="datasets", data_pipeline=data_pipe) assert isinstance(out[0], int) diff --git a/tests/graph/embedding/test_model.py b/tests/graph/embedding/test_model.py index f7c15b1095..c06da323e9 100644 --- a/tests/graph/embedding/test_model.py +++ b/tests/graph/embedding/test_model.py @@ -59,5 +59,5 @@ def test_predict_dataset(tmpdir): GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes).backbone ) data_pipe = DataPipeline(input_transform=GraphClassificationInputTransform()) - out = model.predict(tudataset, data_source="datasets", data_pipeline=data_pipe) + out = model.predict(tudataset, input="datasets", data_pipeline=data_pipe) assert isinstance(out[0], torch.Tensor) diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index 4e2a3c79ab..e1d3d501cc 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys 
from flash.core.data.transforms import ApplyToKeys, merge_transforms from flash.core.utilities.imports import ( _ALBUMENTATIONS_AVAILABLE, @@ -229,8 +229,8 @@ def test_from_filepaths_splits(tmpdir): _to_tensor = { "to_tensor_transform": nn.Sequential( - ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), - ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ), } @@ -467,19 +467,19 @@ def test_from_datasets(): # check training data data = next(iter(img_data.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2,) # check validation data data = next(iter(img_data.val_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2,) # check test data data = next(iter(img_data.test_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2,) @@ -515,7 +515,7 @@ def test_from_csv_single_target(single_target_csv): # check training data data = next(iter(img_data.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2,) @@ -543,7 +543,7 @@ def test_from_csv_multi_target(multi_target_csv): # check training data data = next(iter(img_data.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (2, 
3, 196, 196) assert labels.shape == (2, 2) diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py index 779da0c646..89ed32ce76 100644 --- a/tests/image/classification/test_model.py +++ b/tests/image/classification/test_model.py @@ -21,7 +21,7 @@ from flash import Trainer from flash.__main__ import main from flash.core.classification import Probabilities -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _IMAGE_AVAILABLE from flash.image import ImageClassifier from flash.image.classification.data import ImageClassificationInputTransform @@ -33,8 +33,8 @@ class DummyDataset(torch.utils.data.Dataset): def __getitem__(self, index): return { - DefaultDataKeys.INPUT: torch.rand(3, 224, 224), - DefaultDataKeys.TARGET: torch.randint(10, size=(1,)).item(), + DataKeys.INPUT: torch.rand(3, 224, 224), + DataKeys.TARGET: torch.randint(10, size=(1,)).item(), } def __len__(self) -> int: @@ -47,8 +47,8 @@ def __init__(self, num_classes: int): def __getitem__(self, index): return { - DefaultDataKeys.INPUT: torch.rand(3, 224, 224), - DefaultDataKeys.TARGET: torch.randint(0, 2, (self.num_classes,)), + DataKeys.INPUT: torch.rand(3, 224, 224), + DataKeys.TARGET: torch.randint(0, 2, (self.num_classes,)), } def __len__(self) -> int: @@ -108,8 +108,8 @@ def test_multilabel(tmpdir): train_dl = torch.utils.data.DataLoader(ds, batch_size=2) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.finetune(model, train_dl, strategy="freeze_unfreeze") - image, label = ds[0][DefaultDataKeys.INPUT], ds[0][DefaultDataKeys.TARGET] - predictions = model.predict([{DefaultDataKeys.INPUT: image}]) + image, label = ds[0][DataKeys.INPUT], ds[0][DataKeys.TARGET] + predictions = model.predict([{DataKeys.INPUT: image}]) assert (torch.tensor(predictions) > 1).sum() == 0 assert (torch.tensor(predictions) < 0).sum() == 0 assert len(predictions[0]) == num_classes 
== len(label) diff --git a/tests/image/classification/test_training_strategies.py b/tests/image/classification/test_training_strategies.py index 746880b4be..6c6950da09 100644 --- a/tests/image/classification/test_training_strategies.py +++ b/tests/image/classification/test_training_strategies.py @@ -19,7 +19,7 @@ from torch.utils.data import DataLoader from flash import Trainer -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _LEARN2LEARN_AVAILABLE from flash.image import ImageClassificationData, ImageClassifier from flash.image.classification.adapters import TRAINING_STRATEGIES @@ -32,8 +32,8 @@ class DummyDataset(torch.utils.data.Dataset): def __getitem__(self, index): return { - DefaultDataKeys.INPUT: torch.rand(3, 96, 96), - DefaultDataKeys.TARGET: torch.randint(10, size=(1,)).item(), + DataKeys.INPUT: torch.rand(3, 96, 96), + DataKeys.TARGET: torch.randint(10, size=(1,)).item(), } def __len__(self) -> int: diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py index 3f8d700704..875d0d9711 100644 --- a/tests/image/detection/test_data.py +++ b/tests/image/detection/test_data.py @@ -17,7 +17,7 @@ import pytest -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _PIL_AVAILABLE from flash.image.detection.data import ObjectDetectionData @@ -154,7 +154,7 @@ def test_image_detector_data_from_coco(tmpdir): data = next(iter(datamodule.train_dataloader())) sample = data[0] - assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) + assert sample[DataKeys.INPUT].shape == (128, 128, 3) datamodule = ObjectDetectionData.from_coco( train_folder=train_folder, @@ -171,11 +171,11 @@ def test_image_detector_data_from_coco(tmpdir): data = next(iter(datamodule.val_dataloader())) sample = data[0] - assert 
sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) + assert sample[DataKeys.INPUT].shape == (128, 128, 3) data = next(iter(datamodule.test_dataloader())) sample = data[0] - assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) + assert sample[DataKeys.INPUT].shape == (128, 128, 3) @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @@ -188,7 +188,7 @@ def test_image_detector_data_from_fiftyone(tmpdir): data = next(iter(datamodule.train_dataloader())) sample = data[0] - assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) + assert sample[DataKeys.INPUT].shape == (128, 128, 3) datamodule = ObjectDetectionData.from_fiftyone( train_dataset=train_dataset, @@ -201,8 +201,8 @@ def test_image_detector_data_from_fiftyone(tmpdir): data = next(iter(datamodule.val_dataloader())) sample = data[0] - assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) + assert sample[DataKeys.INPUT].shape == (128, 128, 3) data = next(iter(datamodule.test_dataloader())) sample = data[0] - assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) + assert sample[DataKeys.INPUT].shape == (128, 128, 3) diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index 3893cdc242..5c4997b151 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -22,7 +22,7 @@ from torch.utils.data import Dataset from flash.__main__ import main -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE from flash.image import ObjectDetector @@ -53,16 +53,16 @@ def __getitem__(self, idx): img = np.random.rand(*self.img_shape).astype(np.float32) - sample[DefaultDataKeys.INPUT] = img + sample[DataKeys.INPUT] = img - sample[DefaultDataKeys.TARGET] = { + sample[DataKeys.TARGET] = { "bboxes": [], "labels": [], } for i in range(self.num_boxes): - 
sample[DefaultDataKeys.TARGET]["bboxes"].append(self._random_bbox()) - sample[DefaultDataKeys.TARGET]["labels"].append(random.randint(0, self.num_classes - 1)) + sample[DataKeys.TARGET]["bboxes"].append(self._random_bbox()) + sample[DataKeys.TARGET]["labels"].append(random.randint(0, self.num_classes - 1)) return sample @@ -78,7 +78,7 @@ def test_init(): dl = model.process_predict_dataset(ds, batch_size=batch_size) data = next(iter(dl)) - out = model.forward(data[DefaultDataKeys.INPUT]) + out = model.forward(data[DataKeys.INPUT]) assert len(out) == batch_size assert all(isinstance(res, dict) for res in out) diff --git a/tests/image/detection/test_output.py b/tests/image/detection/test_output.py index 9023106c02..002fb68f41 100644 --- a/tests/image/detection/test_output.py +++ b/tests/image/detection/test_output.py @@ -2,7 +2,7 @@ import pytest import torch -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE from flash.image.detection.output import FiftyOneDetectionLabels @@ -24,7 +24,7 @@ def test_serialize_fiftyone(): labels_serial = FiftyOneDetectionLabels(labels=labels) sample = { - DefaultDataKeys.PREDS: { + DataKeys.PREDS: { "bboxes": [ { "xmin": torch.tensor(20), @@ -36,7 +36,7 @@ def test_serialize_fiftyone(): "labels": [torch.tensor(0)], "scores": [torch.tensor(0.5)], }, - DefaultDataKeys.METADATA: { + DataKeys.METADATA: { "filepath": "something", "size": (100, 100), }, diff --git a/tests/image/embedding/utils.py b/tests/image/embedding/utils.py index ee7fe2bd13..a6dc181b22 100644 --- a/tests/image/embedding/utils.py +++ b/tests/image/embedding/utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.io.input_transform import DefaultInputTransform from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE @@ -38,7 +38,7 @@ def ssl_datamodule( ) to_tensor_transform = ApplyToKeys( - DefaultDataKeys.INPUT, + DataKeys.INPUT, multi_crop_transform, ) input_transform = DefaultInputTransform( diff --git a/tests/image/segmentation/test_data.py b/tests/image/segmentation/test_data.py index 68e5e3f758..1cd8d061c3 100644 --- a/tests/image/segmentation/test_data.py +++ b/tests/image/segmentation/test_data.py @@ -8,7 +8,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from flash import Trainer -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE from flash.image import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationInputTransform from tests.helpers.utils import _IMAGE_TESTING @@ -108,19 +108,19 @@ def test_from_folders(tmpdir): # check training data data = next(iter(dm.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (2, 3, 128, 128) assert labels.shape == (2, 128, 128) # check val data data = next(iter(dm.val_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (2, 3, 128, 128) assert labels.shape == (2, 128, 128) # check test data data = next(iter(dm.test_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert 
imgs.shape == (2, 3, 128, 128) assert labels.shape == (2, 128, 128) @@ -163,7 +163,7 @@ def test_from_folders_warning(tmpdir): # check training data data = next(iter(dm.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (1, 3, 128, 128) assert labels.shape == (1, 128, 128) @@ -210,19 +210,19 @@ def test_from_files(tmpdir): # check training data data = next(iter(dm.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (2, 3, 128, 128) assert labels.shape == (2, 128, 128) # check val data data = next(iter(dm.val_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (2, 3, 128, 128) assert labels.shape == (2, 128, 128) # check test data data = next(iter(dm.test_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (2, 3, 128, 128) assert labels.shape == (2, 128, 128) @@ -309,25 +309,25 @@ def test_from_fiftyone(tmpdir): # check training data data = next(iter(dm.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (2, 3, 128, 128) assert labels.shape == (2, 128, 128) # check val data data = next(iter(dm.val_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (2, 3, 128, 128) assert labels.shape == (2, 128, 128) # check test data data = next(iter(dm.test_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, 
labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (2, 3, 128, 128) assert labels.shape == (2, 128, 128) # check predict data data = next(iter(dm.predict_dataloader())) - imgs = data[DefaultDataKeys.INPUT] + imgs = data[DataKeys.INPUT] assert imgs.shape == (2, 3, 128, 128) @staticmethod @@ -383,7 +383,7 @@ def test_map_labels(tmpdir): # check training data data = next(iter(dm.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] assert imgs.shape == (2, 3, 128, 128) assert labels.shape == (2, 128, 128) assert labels.min().item() == 0 diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py index 41443e470e..821c7bddd6 100644 --- a/tests/image/segmentation/test_model.py +++ b/tests/image/segmentation/test_model.py @@ -23,7 +23,7 @@ from flash import Trainer from flash.__main__ import main from flash.core.data.data_pipeline import DataPipeline -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _IMAGE_AVAILABLE from flash.image import SemanticSegmentation from flash.image.segmentation.data import SemanticSegmentationInputTransform @@ -38,8 +38,8 @@ class DummyDataset(torch.utils.data.Dataset): def __getitem__(self, index): return { - DefaultDataKeys.INPUT: torch.rand(3, *self.size), - DefaultDataKeys.TARGET: torch.randint(self.num_classes - 1, self.size), + DataKeys.INPUT: torch.rand(3, *self.size), + DataKeys.TARGET: torch.randint(self.num_classes - 1, self.size), } def __len__(self) -> int: @@ -107,7 +107,7 @@ def test_predict_tensor(): img = torch.rand(1, 3, 64, 64) model = SemanticSegmentation(2, backbone="mobilenetv3_large_100") data_pipe = DataPipeline(input_transform=SemanticSegmentationInputTransform(num_classes=1)) - out = model.predict(img, data_source="tensors", data_pipeline=data_pipe) + out = 
model.predict(img, input="tensors", data_pipeline=data_pipe) assert isinstance(out[0], list) assert len(out[0]) == 64 assert len(out[0][0]) == 64 @@ -118,7 +118,7 @@ def test_predict_numpy(): img = np.ones((1, 3, 64, 64)) model = SemanticSegmentation(2, backbone="mobilenetv3_large_100") data_pipe = DataPipeline(input_transform=SemanticSegmentationInputTransform(num_classes=1)) - out = model.predict(img, data_source="numpy", data_pipeline=data_pipe) + out = model.predict(img, input="numpy", data_pipeline=data_pipe) assert isinstance(out[0], list) assert len(out[0]) == 64 assert len(out[0][0]) == 64 diff --git a/tests/image/segmentation/test_output.py b/tests/image/segmentation/test_output.py index ad06cd2bd8..767bd49ccf 100644 --- a/tests/image/segmentation/test_output.py +++ b/tests/image/segmentation/test_output.py @@ -14,7 +14,7 @@ import pytest import torch -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE from flash.image.segmentation.output import FiftyOneSegmentationLabels, SegmentationLabels from tests.helpers.utils import _IMAGE_TESTING @@ -51,7 +51,7 @@ def test_serialize(): sample[1, 1, 2] = 1 # add peak in class 2 sample[3, 0, 1] = 1 # add peak in class 4 - classes = serial.transform({DefaultDataKeys.PREDS: sample}) + classes = serial.transform({DataKeys.PREDS: sample}) assert torch.tensor(classes)[1, 2] == 1 assert torch.tensor(classes)[0, 1] == 3 @@ -67,8 +67,8 @@ def test_serialize_fiftyone(): preds[3, 0, 1] = 1 # add peak in class 4 sample = { - DefaultDataKeys.PREDS: preds, - DefaultDataKeys.METADATA: {"filepath": "something"}, + DataKeys.PREDS: preds, + DataKeys.METADATA: {"filepath": "something"}, } segmentation = serial.transform(sample) diff --git a/tests/pointcloud/detection/test_data.py b/tests/pointcloud/detection/test_data.py index 3fa6248107..56ebf4b078 100644 --- a/tests/pointcloud/detection/test_data.py +++ 
b/tests/pointcloud/detection/test_data.py @@ -18,7 +18,7 @@ from pytorch_lightning import seed_everything from flash import Trainer -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.utils import download_data from flash.pointcloud.detection import PointCloudObjectDetector, PointCloudObjectDetectorData from tests.helpers.utils import _POINTCLOUD_TESTING @@ -54,5 +54,5 @@ def training_step(self, batch, batch_idx: int): model.eval() predictions = model.predict([join(predict_path, "scans/000000.bin")]) - assert predictions[0][DefaultDataKeys.INPUT].shape[1] == 4 - assert len(predictions[0][DefaultDataKeys.PREDS]) == 158 + assert predictions[0][DataKeys.INPUT].shape[1] == 4 + assert len(predictions[0][DataKeys.PREDS]) == 158 diff --git a/tests/pointcloud/segmentation/test_data.py b/tests/pointcloud/segmentation/test_data.py index 400da2c0c4..bce4693f2c 100644 --- a/tests/pointcloud/segmentation/test_data.py +++ b/tests/pointcloud/segmentation/test_data.py @@ -18,7 +18,7 @@ from pytorch_lightning import seed_everything from flash import Trainer -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.data.utils import download_data from flash.pointcloud.segmentation import PointCloudSegmentation, PointCloudSegmentationData from tests.helpers.utils import _POINTCLOUD_TESTING @@ -35,15 +35,15 @@ def test_pointcloud_segmentation_data(tmpdir): class MockModel(PointCloudSegmentation): def training_step(self, batch, batch_idx: int): - assert batch[DefaultDataKeys.INPUT]["xyz"][0].shape == torch.Size([2, 45056, 3]) - assert batch[DefaultDataKeys.INPUT]["xyz"][1].shape == torch.Size([2, 11264, 3]) - assert batch[DefaultDataKeys.INPUT]["xyz"][2].shape == torch.Size([2, 2816, 3]) - assert batch[DefaultDataKeys.INPUT]["xyz"][3].shape == torch.Size([2, 704, 3]) - assert batch[DefaultDataKeys.INPUT]["labels"].shape == torch.Size([2, 45056]) - 
assert batch[DefaultDataKeys.INPUT]["labels"].max() == 19 - assert batch[DefaultDataKeys.INPUT]["labels"].min() == 0 - assert batch[DefaultDataKeys.METADATA][0]["name"] in ("00_000000", "00_000001") - assert batch[DefaultDataKeys.METADATA][1]["name"] in ("00_000000", "00_000001") + assert batch[DataKeys.INPUT]["xyz"][0].shape == torch.Size([2, 45056, 3]) + assert batch[DataKeys.INPUT]["xyz"][1].shape == torch.Size([2, 11264, 3]) + assert batch[DataKeys.INPUT]["xyz"][2].shape == torch.Size([2, 2816, 3]) + assert batch[DataKeys.INPUT]["xyz"][3].shape == torch.Size([2, 704, 3]) + assert batch[DataKeys.INPUT]["labels"].shape == torch.Size([2, 45056]) + assert batch[DataKeys.INPUT]["labels"].max() == 19 + assert batch[DataKeys.INPUT]["labels"].min() == 0 + assert batch[DataKeys.METADATA][0]["name"] in ("00_000000", "00_000001") + assert batch[DataKeys.METADATA][1]["name"] in ("00_000000", "00_000001") num_classes = 19 model = MockModel(backbone="randlanet", num_classes=num_classes) @@ -51,6 +51,6 @@ def training_step(self, batch, batch_idx: int): trainer.fit(model, dm) predictions = model.predict(join(tmpdir, "SemanticKittiMicro", "predict")) - assert predictions[0][DefaultDataKeys.INPUT].shape == torch.Size([45056, 3]) - assert predictions[0][DefaultDataKeys.PREDS].shape == torch.Size([45056, 19]) - assert predictions[0][DefaultDataKeys.TARGET].shape == torch.Size([45056]) + assert predictions[0][DataKeys.INPUT].shape == torch.Size([45056, 3]) + assert predictions[0][DataKeys.PREDS].shape == torch.Size([45056, 19]) + assert predictions[0][DataKeys.TARGET].shape == torch.Size([45056]) diff --git a/tests/tabular/classification/test_data.py b/tests/tabular/classification/test_data.py index 4f1fc632d8..b8d636d87b 100644 --- a/tests/tabular/classification/test_data.py +++ b/tests/tabular/classification/test_data.py @@ -17,7 +17,7 @@ import numpy as np import pytest -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from 
flash.core.utilities.imports import _PANDAS_AVAILABLE if _PANDAS_AVAILABLE: @@ -110,8 +110,8 @@ def test_categorical_target(tmpdir): ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: data = next(iter(dl)) - (cat, num) = data[DefaultDataKeys.INPUT] - target = data[DefaultDataKeys.TARGET] + (cat, num) = data[DataKeys.INPUT] + target = data[DataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) assert target.shape == (1,) @@ -134,8 +134,8 @@ def test_from_data_frame(tmpdir): ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: data = next(iter(dl)) - (cat, num) = data[DefaultDataKeys.INPUT] - target = data[DefaultDataKeys.TARGET] + (cat, num) = data[DataKeys.INPUT] + target = data[DataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) assert target.shape == (1,) @@ -161,8 +161,8 @@ def test_from_csv(tmpdir): ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: data = next(iter(dl)) - (cat, num) = data[DefaultDataKeys.INPUT] - target = data[DefaultDataKeys.TARGET] + (cat, num) = data[DataKeys.INPUT] + target = data[DataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) assert target.shape == (1,) diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index 9ebfe8eb97..43df635fb8 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -19,7 +19,7 @@ import torch from pytorch_lightning import Trainer -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _TABULAR_AVAILABLE from flash.tabular.classification.data import TabularClassificationData from flash.tabular.classification.model import TabularClassifier @@ -38,7 +38,7 @@ def __getitem__(self, index): target = torch.randint(0, 10, size=(1,)).item() cat_vars = torch.randint(0, 10, size=(self.num_cat,)) 
num_vars = torch.rand(self.num_num) - return {DefaultDataKeys.INPUT: (cat_vars, num_vars), DefaultDataKeys.TARGET: target} + return {DataKeys.INPUT: (cat_vars, num_vars), DataKeys.TARGET: target} def __len__(self) -> int: return 100 diff --git a/tests/template/classification/test_data.py b/tests/template/classification/test_data.py index 03f11b3c81..a1ff989c7b 100644 --- a/tests/template/classification/test_data.py +++ b/tests/template/classification/test_data.py @@ -14,7 +14,7 @@ import numpy as np import pytest -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _SKLEARN_AVAILABLE from flash.template.classification.data import TemplateData, TemplateInputTransform @@ -69,19 +69,19 @@ def test_from_numpy(self): # check training data data = next(iter(dm.train_dataloader())) - rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + rows, targets = data[DataKeys.INPUT], data[DataKeys.TARGET] assert rows.shape == (2, self.num_features) assert targets.shape == (2,) # check val data data = next(iter(dm.val_dataloader())) - rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + rows, targets = data[DataKeys.INPUT], data[DataKeys.TARGET] assert rows.shape == (2, self.num_features) assert targets.shape == (2,) # check test data data = next(iter(dm.test_dataloader())) - rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + rows, targets = data[DataKeys.INPUT], data[DataKeys.TARGET] assert rows.shape == (2, self.num_features) assert targets.shape == (2,) @@ -105,18 +105,18 @@ def test_from_sklearn(): # check training data data = next(iter(dm.train_dataloader())) - rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + rows, targets = data[DataKeys.INPUT], data[DataKeys.TARGET] assert rows.shape == (2, dm.num_features) assert targets.shape == (2,) # check val data data = next(iter(dm.val_dataloader())) 
- rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + rows, targets = data[DataKeys.INPUT], data[DataKeys.TARGET] assert rows.shape == (2, dm.num_features) assert targets.shape == (2,) # check test data data = next(iter(dm.test_dataloader())) - rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + rows, targets = data[DataKeys.INPUT], data[DataKeys.TARGET] assert rows.shape == (2, dm.num_features) assert targets.shape == (2,) diff --git a/tests/template/classification/test_model.py b/tests/template/classification/test_model.py index 0c585b842d..a3ab60d97c 100644 --- a/tests/template/classification/test_model.py +++ b/tests/template/classification/test_model.py @@ -19,7 +19,7 @@ from flash import Trainer from flash.core.data.data_pipeline import DataPipeline -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _SKLEARN_AVAILABLE from flash.template import TemplateSKLearnClassifier from flash.template.classification.data import TemplateInputTransform @@ -38,8 +38,8 @@ class DummyDataset(torch.utils.data.Dataset): def __getitem__(self, index): return { - DefaultDataKeys.INPUT: torch.randn(self.num_features), - DefaultDataKeys.TARGET: torch.randint(self.num_classes - 1, (1,))[0], + DataKeys.INPUT: torch.randn(self.num_features), + DataKeys.TARGET: torch.randint(self.num_classes - 1, (1,))[0], } def __len__(self) -> int: @@ -116,7 +116,7 @@ def test_predict_sklearn(): bunch = datasets.load_iris() model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes) data_pipe = DataPipeline(input_transform=TemplateInputTransform()) - out = model.predict(bunch, data_source="sklearn", data_pipeline=data_pipe) + out = model.predict(bunch, input="sklearn", data_pipeline=data_pipe) assert isinstance(out[0], int) diff --git a/tests/text/classification/test_data.py 
b/tests/text/classification/test_data.py index fc8e8e15d0..3453f80260 100644 --- a/tests/text/classification/test_data.py +++ b/tests/text/classification/test_data.py @@ -17,17 +17,17 @@ import pandas as pd import pytest -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import TextClassificationData from flash.text.classification.data import ( - TextCSVDataSource, - TextDataFrameDataSource, - TextDataSource, - TextHuggingFaceDatasetDataSource, - TextJSONDataSource, - TextListDataSource, - TextParquetDataSource, + TextCSVInput, + TextDataFrameInput, + TextHuggingFaceDatasetInput, + TextInput, + TextJSONInput, + TextListInput, + TextParquetInput, ) from tests.helpers.utils import _TEXT_TESTING @@ -135,15 +135,15 @@ def test_from_csv(tmpdir): ) batch = next(iter(dm.train_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.val_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.test_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.predict_dataloader())) @@ -166,15 +166,15 @@ def test_from_csv_multilabel(tmpdir): ) batch = next(iter(dm.train_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.val_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.test_dataloader())) - assert 
all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.predict_dataloader())) @@ -197,15 +197,15 @@ def test_from_json(tmpdir): ) batch = next(iter(dm.train_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.val_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.test_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.predict_dataloader())) @@ -228,15 +228,15 @@ def test_from_json_multilabel(tmpdir): ) batch = next(iter(dm.train_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.val_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.test_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.predict_dataloader())) @@ -260,15 +260,15 @@ def test_from_json_with_field(tmpdir): ) batch = next(iter(dm.train_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.val_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in 
batch batch = next(iter(dm.test_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.predict_dataloader())) @@ -292,15 +292,15 @@ def test_from_json_with_field_multilabel(tmpdir): ) batch = next(iter(dm.train_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.val_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.test_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.predict_dataloader())) @@ -323,15 +323,15 @@ def test_from_parquet(tmpdir): ) batch = next(iter(dm.train_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.val_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.test_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.predict_dataloader())) @@ -354,15 +354,15 @@ def test_from_parquet_multilabel(tmpdir): ) batch = next(iter(dm.train_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.val_dataloader())) - assert all([label in [0, 1] for label in 
batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.test_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.predict_dataloader())) @@ -384,15 +384,15 @@ def test_from_data_frame(): ) batch = next(iter(dm.train_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.val_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.test_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.predict_dataloader())) @@ -414,15 +414,15 @@ def test_from_data_frame_multilabel(): ) batch = next(iter(dm.train_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.val_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.test_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.predict_dataloader())) @@ -445,15 +445,15 @@ def test_from_hf_datasets(): ) batch = next(iter(dm.train_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in 
batch batch = next(iter(dm.val_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.test_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.predict_dataloader())) @@ -476,15 +476,15 @@ def test_from_hf_datasets_multilabel(): ) batch = next(iter(dm.train_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.val_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.test_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.predict_dataloader())) @@ -507,15 +507,15 @@ def test_from_lists(): ) batch = next(iter(dm.train_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.val_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.test_dataloader())) - assert batch[DefaultDataKeys.TARGET].item() in [0, 1] + assert batch[DataKeys.TARGET].item() in [0, 1] assert "input_ids" in batch batch = next(iter(dm.predict_dataloader())) @@ -538,15 +538,15 @@ def test_from_lists_multilabel(): ) batch = next(iter(dm.train_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in 
batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.val_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.test_dataloader())) - assert all([label in [0, 1] for label in batch[DefaultDataKeys.TARGET][0]]) + assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) assert "input_ids" in batch batch = next(iter(dm.predict_dataloader())) @@ -564,13 +564,13 @@ def test_text_module_not_found_error(): @pytest.mark.parametrize( "cls, kwargs", [ - (TextDataSource, {}), - (TextCSVDataSource, {}), - (TextJSONDataSource, {}), - (TextDataFrameDataSource, {}), - (TextParquetDataSource, {}), - (TextHuggingFaceDatasetDataSource, {}), - (TextListDataSource, {}), + (TextInput, {}), + (TextCSVInput, {}), + (TextJSONInput, {}), + (TextDataFrameInput, {}), + (TextParquetInput, {}), + (TextHuggingFaceDatasetInput, {}), + (TextListInput, {}), ], ) def test_tokenizer_state(cls, kwargs): diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py index e336ff81c5..db3db985b4 100644 --- a/tests/text/classification/test_model.py +++ b/tests/text/classification/test_model.py @@ -20,7 +20,7 @@ from flash import Trainer from flash.__main__ import main -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import TextClassifier from flash.text.classification.data import TextClassificationInputTransform, TextClassificationOutputTransform @@ -33,7 +33,7 @@ class DummyDataset(torch.utils.data.Dataset): def __getitem__(self, index): return { "input_ids": torch.randint(1000, size=(100,)), - DefaultDataKeys.TARGET: torch.randint(2, size=(1,)).item(), + DataKeys.TARGET: torch.randint(2, size=(1,)).item(), } def __len__(self) -> int: diff --git 
a/tests/text/question_answering/test_data.py b/tests/text/question_answering/test_data.py index a57bf4f2dd..862f8c3b0b 100644 --- a/tests/text/question_answering/test_data.py +++ b/tests/text/question_answering/test_data.py @@ -18,7 +18,7 @@ import pandas as pd import pytest -from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.io.input import DataKeys from flash.text import QuestionAnsweringData from tests.helpers.utils import _TEXT_TESTING @@ -138,22 +138,22 @@ def test_from_files(tmpdir): assert "attention_mask" in batch assert "start_positions" in batch assert "end_positions" in batch - assert DefaultDataKeys.METADATA in batch - assert "context" in batch[DefaultDataKeys.METADATA][0] - assert "answer" in batch[DefaultDataKeys.METADATA][0] - assert "example_id" in batch[DefaultDataKeys.METADATA][0] - assert "offset_mapping" in batch[DefaultDataKeys.METADATA][0] + assert DataKeys.METADATA in batch + assert "context" in batch[DataKeys.METADATA][0] + assert "answer" in batch[DataKeys.METADATA][0] + assert "example_id" in batch[DataKeys.METADATA][0] + assert "offset_mapping" in batch[DataKeys.METADATA][0] batch = next(iter(dm.test_dataloader())) assert "input_ids" in batch assert "attention_mask" in batch assert "start_positions" in batch assert "end_positions" in batch - assert DefaultDataKeys.METADATA in batch - assert "context" in batch[DefaultDataKeys.METADATA][0] - assert "answer" in batch[DefaultDataKeys.METADATA][0] - assert "example_id" in batch[DefaultDataKeys.METADATA][0] - assert "offset_mapping" in batch[DefaultDataKeys.METADATA][0] + assert DataKeys.METADATA in batch + assert "context" in batch[DataKeys.METADATA][0] + assert "answer" in batch[DataKeys.METADATA][0] + assert "example_id" in batch[DataKeys.METADATA][0] + assert "offset_mapping" in batch[DataKeys.METADATA][0] @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") diff --git a/tests/text/seq2seq/core/test_data.py 
b/tests/text/seq2seq/core/test_data.py index 8b9e2f862c..d38f89c63c 100644 --- a/tests/text/seq2seq/core/test_data.py +++ b/tests/text/seq2seq/core/test_data.py @@ -18,12 +18,12 @@ from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text.seq2seq.core.data import ( Seq2SeqBackboneState, - Seq2SeqCSVDataSource, - Seq2SeqDataSource, - Seq2SeqFileDataSource, - Seq2SeqJSONDataSource, + Seq2SeqCSVInput, + Seq2SeqFileInput, + Seq2SeqInput, + Seq2SeqJSONInput, Seq2SeqOutputTransform, - Seq2SeqSentencesDataSource, + Seq2SeqSentencesInput, ) from tests.helpers.utils import _TEXT_TESTING @@ -36,11 +36,11 @@ @pytest.mark.parametrize( "cls, kwargs", [ - (Seq2SeqDataSource, {"backbone": "sshleifer/tiny-mbart"}), - (Seq2SeqFileDataSource, {"backbone": "sshleifer/tiny-mbart", "filetype": "csv"}), - (Seq2SeqCSVDataSource, {"backbone": "sshleifer/tiny-mbart"}), - (Seq2SeqJSONDataSource, {"backbone": "sshleifer/tiny-mbart"}), - (Seq2SeqSentencesDataSource, {"backbone": "sshleifer/tiny-mbart"}), + (Seq2SeqInput, {"backbone": "sshleifer/tiny-mbart"}), + (Seq2SeqFileInput, {"backbone": "sshleifer/tiny-mbart", "filetype": "csv"}), + (Seq2SeqCSVInput, {"backbone": "sshleifer/tiny-mbart"}), + (Seq2SeqJSONInput, {"backbone": "sshleifer/tiny-mbart"}), + (Seq2SeqSentencesInput, {"backbone": "sshleifer/tiny-mbart"}), (Seq2SeqOutputTransform, {}), ], )