diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6d0283c18c..f60c2461c8 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -26,3 +26,4 @@ /.github/*.md @edenlightning @ethanwharris @ananyahjha93 /.github/ISSUE_TEMPLATE/*.md @edenlightning @ethanwharris @ananyahjha93 /docs/source/conf.py @borda @ethanwharris @ananyahjha93 +/flash/core/integrations/labelstudio @KonstantinKorotaev @niklub diff --git a/CHANGELOG.md b/CHANGELOG.md index 11d4b0accf..91f7061a4b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `LabelStudio` integration ([#554](https://github.com/PyTorchLightning/lightning-flash/pull/554)) + - Added support `learn2learn` training_strategy for `ImageClassifier` ([#737](https://github.com/PyTorchLightning/lightning-flash/pull/737)) - Added `vissl` training_strategies for `ImageEmbedder` ([#682](https://github.com/PyTorchLightning/lightning-flash/pull/682)) diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 970e73160f..fdf3f22e48 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -1246,3 +1246,128 @@ def from_fiftyone( num_workers=num_workers, **preprocess_kwargs, ) + + @classmethod + def from_labelstudio( + cls, + export_json: str = None, + train_export_json: str = None, + val_export_json: str = None, + test_export_json: str = None, + predict_export_json: str = None, + data_folder: str = None, + train_data_folder: str = None, + val_data_folder: str = None, + test_data_folder: str = None, + predict_data_folder: str = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ) -> "DataModule": + """Creates a :class:`~flash.core.data.data_module.DataModule` object + from the given export file and data directory using the + :class:`~flash.core.data.data_source.DataSource` of name + :attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS` + from the passed or constructed :class:`~flash.core.data.process.Preprocess`. + + Args: + export_json: path to label studio export file + train_export_json: path to label studio export file for train set, + overrides export_json if specified + val_export_json: path to label studio export file for validation + test_export_json: path to label studio export file for test + predict_export_json: path to label studio export file for predict + data_folder: path to label studio data folder + train_data_folder: path to label studio data folder for train data set, + overrides data_folder if specified + val_data_folder: path to label studio data folder for validation data + test_data_folder: path to label studio data folder for test data + predict_data_folder: path to label studio data folder for predict data + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = DataModule.from_labelstudio( + export_json='project.json', + data_folder='label-studio/media/upload', + val_split=0.8, + ) + """ + data = { + "data_folder": data_folder, + "export_json": export_json, + "split": val_split, + "multi_label": preprocess_kwargs.get("multi_label", False), + } + train_data = None + val_data = None + test_data = None + predict_data = None + if (train_data_folder or data_folder) and train_export_json: + train_data = { + "data_folder": train_data_folder or data_folder, + "export_json": train_export_json, + "multi_label": preprocess_kwargs.get("multi_label", False), + } + if (val_data_folder or data_folder) and val_export_json: + val_data = { + "data_folder": val_data_folder or data_folder, + "export_json": val_export_json, + "multi_label": preprocess_kwargs.get("multi_label", False), + } + if (test_data_folder or data_folder) and test_export_json: + test_data = { + "data_folder": test_data_folder or data_folder, + "export_json": test_export_json, + "multi_label": preprocess_kwargs.get("multi_label", False), + } + if (predict_data_folder or data_folder) and predict_export_json: + predict_data = { + "data_folder": predict_data_folder or data_folder, + "export_json": predict_export_json, + "multi_label": preprocess_kwargs.get("multi_label", False), + } + return cls.from_data_source( + DefaultDataSources.LABELSTUDIO, + train_data=train_data if train_data else data, + val_data=val_data, + test_data=test_data, + predict_data=predict_data, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py index 6b3e53dea9..fb4260ed89 100644 --- a/flash/core/data/data_source.py +++ b/flash/core/data/data_source.py @@ -160,6 +160,7 @@ class DefaultDataSources(LightningEnum): JSON = "json" DATASETS = "datasets" FIFTYONE = "fiftyone" + LABELSTUDIO = "labelstudio" # TODO: Create a FlashEnum class??? def __hash__(self) -> int: diff --git a/flash/core/integrations/labelstudio/data_source.py b/flash/core/integrations/labelstudio/data_source.py new file mode 100644 index 0000000000..5e2587f7f6 --- /dev/null +++ b/flash/core/integrations/labelstudio/data_source.py @@ -0,0 +1,266 @@ +import json +import os +from pathlib import Path +from typing import Any, Mapping, Optional, Sequence, TypeVar, Union + +import torch +from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities.cloud_io import get_filesystem + +from flash import DataSource +from flash.core.data.auto_dataset import AutoDataset, IterableAutoDataset +from flash.core.data.data_source import DefaultDataKeys, has_len +from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE, _TEXT_AVAILABLE, _TORCHVISION_AVAILABLE + +if _TORCHVISION_AVAILABLE: + from torchvision.datasets.folder import default_loader +DATA_TYPE = TypeVar("DATA_TYPE") + + +class LabelStudioDataSource(DataSource): + """The ``LabelStudioDatasource`` expects the input to + :meth:`~flash.core.data.data_source.DataSource.load_data` to be a json export from label studio.""" + + def __init__(self): + super().__init__() + self.results = [] + self.test_results = [] + self.val_results = [] + self.classes = set() + self.data_types = set() + self.tag_types = set() + self.num_classes = 0 + self._data_folder = "" + self._raw_data = {} + self.multi_label = False + self.split = None + + def load_data(self, data: Optional[Any] = None, dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: + """Iterate through all tasks in exported data and construct train\test\val results.""" + if data and isinstance(data, dict): + data_folder = data.get("data_folder") + file_path = data.get("export_json") + fs = get_filesystem(file_path) + with fs.open(file_path) as f: + _raw_data = json.load(f) + self.multi_label = data.get("multi_label", False) + self.split = data.get("split") + results, test_results, classes, data_types, tag_types = LabelStudioDataSource._load_json_data( + _raw_data, data_folder=data_folder, multi_label=self.multi_label + ) + self.classes = self.classes | classes + self.data_types = self.data_types | data_types + self.num_classes = len(self.classes) + self.tag_types = self.tag_types | tag_types + # splitting result to train and val sets + if self.split: + import random + + random.shuffle(results) + prop = int(len(results) * self.split) + self.val_results = results[:prop] + self.results = results[prop:] + self.test_results = test_results + return self.results + return results + test_results + return [] + + def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] = None) -> Any: + """Load 1 sample from dataset.""" + # all other data types + # separate label from data + label = self._get_labels_from_sample(sample["label"]) + # delete label from input data + del sample["label"] + result = { + DefaultDataKeys.INPUT: sample, + DefaultDataKeys.TARGET: label, + } + return result + + def generate_dataset( + self, + data: Optional[DATA_TYPE], + running_stage: RunningStage, + ) -> Optional[Union[AutoDataset, IterableAutoDataset]]: + """Generate dataset from loaded data.""" + res = self.load_data(data) + if running_stage in (RunningStage.TRAINING, RunningStage.TUNING): + dataset = res + elif running_stage == RunningStage.TESTING: + dataset = res or self.test_results + elif running_stage == RunningStage.PREDICTING: + dataset = res or [] + elif running_stage == RunningStage.VALIDATING: + dataset = res or self.val_results + + if has_len(dataset): + dataset = AutoDataset(dataset, self, running_stage) + else: + dataset = IterableAutoDataset(dataset, self, running_stage) + dataset.num_classes = self.num_classes + return dataset + + def _get_labels_from_sample(self, labels): + """Translate string labels to int.""" + sorted_labels = sorted(list(self.classes)) + if isinstance(labels, list): + label = [] + for item in labels: + label.append(sorted_labels.index(item)) + else: + label = sorted_labels.index(labels) + return label + + @staticmethod + def _load_json_data(data, data_folder, multi_label=False): + """Utility method to extract data from Label Studio json files.""" + results = [] + test_results = [] + data_types = set() + tag_types = set() + classes = set() + for task in data: + for annotation in task["annotations"]: + # extracting data types from tasks + for key in task.get("data"): + data_types.add(key) + # Adding ground_truth annotation to separate dataset + result = annotation["result"] + for res in result: + t = res["type"] + tag_types.add(t) + for label in res["value"][t]: + # check if labeling result is a list of labels + if isinstance(label, list) and not multi_label: + for sublabel in label: + classes.add(sublabel) + temp = {} + temp["file_upload"] = task.get("file_upload") + temp["data"] = task.get("data") + if temp["file_upload"]: + temp["file_upload"] = os.path.join(data_folder, temp["file_upload"]) + else: + for key in temp["data"]: + p = temp["data"].get(key) + path = Path(p) + if path and data_folder: + temp["file_upload"] = os.path.join(data_folder, path.name) + temp["label"] = sublabel + temp["result"] = res.get("value") + if annotation["ground_truth"]: + test_results.append(temp) + elif not annotation["ground_truth"]: + results.append(temp) + else: + if isinstance(label, list): + for item in label: + classes.add(item) + else: + classes.add(label) + temp = {} + temp["file_upload"] = task.get("file_upload") + temp["data"] = task.get("data") + if temp["file_upload"] and data_folder: + temp["file_upload"] = os.path.join(data_folder, temp["file_upload"]) + else: + for key in temp["data"]: + p = temp["data"].get(key) + path = Path(p) + if path and data_folder: + temp["file_upload"] = os.path.join(data_folder, path.name) + temp["label"] = label + temp["result"] = res.get("value") + if annotation["ground_truth"]: + test_results.append(temp) + elif not annotation["ground_truth"]: + results.append(temp) + return results, test_results, classes, data_types, tag_types + + +class LabelStudioImageClassificationDataSource(LabelStudioDataSource): + """The ``LabelStudioImageDataSource`` expects the input to + :meth:`~flash.core.data.data_source.DataSource.load_data` to be a json export from label studio. + Export data should point to image files""" + + def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] = None) -> Any: + """Load 1 sample from dataset.""" + p = sample["file_upload"] + # loading image + image = default_loader(p) + result = {DefaultDataKeys.INPUT: image, DefaultDataKeys.TARGET: self._get_labels_from_sample(sample["label"])} + return result + + +class LabelStudioTextClassificationDataSource(LabelStudioDataSource): + """The ``LabelStudioTextDataSource`` expects the input to + :meth:`~flash.core.data.data_source.DataSource.load_data` to be a json export from label studio. + Export data should point to text data + """ + + def __init__(self, backbone=None, max_length=128): + super().__init__() + if backbone: + if _TEXT_AVAILABLE: + from transformers import AutoTokenizer + self.backbone = backbone + self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) + self.max_length = max_length + + def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] = None) -> Any: + """Load 1 sample from dataset.""" + if self.backbone: + data = "" + for key in sample.get("data"): + data += sample.get("data").get(key) + tokenized_data = self.tokenizer(data, max_length=self.max_length, truncation=True, padding="max_length") + for key in tokenized_data: + tokenized_data[key] = torch.tensor(tokenized_data[key]) + tokenized_data["labels"] = self._get_labels_from_sample(sample["label"]) + # separate text data type block + result = tokenized_data + return result + + +class LabelStudioVideoClassificationDataSource(LabelStudioDataSource): + """The ``LabelStudioVideoDataSource`` expects the input to + :meth:`~flash.core.data.data_source.DataSource.load_data` to be a json export from label studio. + Export data should point to video files""" + + def __init__(self, video_sampler=None, clip_sampler=None, decode_audio=False, decoder: str = "pyav"): + if not _PYTORCHVIDEO_AVAILABLE: + raise ModuleNotFoundError("Please, run `pip install pytorchvideo`.") + super().__init__() + self.video_sampler = video_sampler or torch.utils.data.RandomSampler + self.clip_sampler = clip_sampler + self.decode_audio = decode_audio + self.decoder = decoder + + def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] = None) -> Any: + """Load 1 sample from dataset.""" + return sample + + def load_data(self, data: Optional[Any] = None, dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: + """load_data produces a sequence or iterable of samples.""" + res = super().load_data(data, dataset) + return self.convert_to_encodedvideo(res) + + def convert_to_encodedvideo(self, dataset): + """Converting dataset to EncodedVideoDataset.""" + if len(dataset) > 0: + from pytorchvideo.data import LabeledVideoDataset + + dataset = LabeledVideoDataset( + [ + ( + os.path.join(self._data_folder, sample["file_upload"]), + {"label": self._get_labels_from_sample(sample["label"])}, + ) + for sample in dataset + ], + self.clip_sampler, + decode_audio=self.decode_audio, + decoder=self.decoder, + ) + return dataset + return [] diff --git a/flash/core/integrations/labelstudio/visualizer.py b/flash/core/integrations/labelstudio/visualizer.py new file mode 100644 index 0000000000..a284eee10b --- /dev/null +++ b/flash/core/integrations/labelstudio/visualizer.py @@ -0,0 +1,101 @@ +import json +import random +import string + +from pytorch_lightning.utilities.cloud_io import get_filesystem + +from flash.core.data.data_module import DataModule + + +class App: + """App for visualizing predictions in Label Studio results format.""" + + def __init__(self, datamodule: DataModule): + self.datamodule = datamodule + + def show_predictions(self, predictions): + """Converts predictions to Label Studio results.""" + results = [] + for pred in predictions: + results.append(self.construct_result(pred)) + return results + + def show_tasks(self, predictions, export_json=None): + """Converts predictions to tasks format.""" + results = self.show_predictions(predictions) + ds = self.datamodule.data_source + data_type = list(ds.data_types)[0] + meta = {"ids": [], "data": [], "meta": [], "max_predictions_id": 0, "project": None} + if export_json: + fs = get_filesystem(export_json) + with fs.open(export_json) as f: + _raw_data = json.load(f) + for task in _raw_data: + if results: + res = results.pop() + meta["max_predictions_id"] = meta["max_predictions_id"] + 1 + temp = { + "result": res["result"], + "id": meta["max_predictions_id"], + "model_version": "", + "score": 0.0, + "task": task["id"], + } + if task.get("predictions"): + task["predictions"].append(temp) + else: + task["predictions"] = [temp] + return _raw_data + else: + print("No export file provided, meta information is generated!") + final_results = [] + for res in results: + temp = { + "result": [res], + "id": meta["max_predictions_id"], + "model_version": "", + "score": 0.0, + "task": meta["max_predictions_id"], + } + task = { + "id": meta["max_predictions_id"], + "predictions": [temp], + "data": {data_type: ""}, + "project": 1, + } + meta["max_predictions_id"] = meta["max_predictions_id"] + 1 + final_results.append(task) + return final_results + + def construct_result(self, pred): + """Construction Label Studio result from data source and prediction values.""" + ds = self.datamodule.data_source + # get label + if isinstance(pred, list): + label = [list(ds.classes)[p] for p in pred] + else: + label = list(ds.classes)[pred] + # get data type, if len(data_types) > 1 take first data type + data_type = list(ds.data_types)[0] + # get tag type, if len(tag_types) > 1 take first tag + tag_type = list(ds.tag_types)[0] + js = { + "result": [ + { + "id": "".join( + random.SystemRandom().choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) + for _ in range(10) + ), + "from_name": "tag", + "to_name": data_type, + "type": tag_type, + "value": {tag_type: label if isinstance(label, list) else [label]}, + } + ] + } + return js + + +def launch_app(datamodule: DataModule) -> "App": + """Creating instance of Visualizing App.""" + return App(datamodule) diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index 5f035949c2..af389fc9ba 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -24,6 +24,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, LoaderDataFrameDataSource from flash.core.data.process import Deserializer, Preprocess +from flash.core.integrations.labelstudio.data_source import LabelStudioImageClassificationDataSource from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, Image, requires from flash.image.classification.transforms import default_transforms, train_default_transforms from flash.image.data import ( @@ -79,6 +80,7 @@ def __init__( DefaultDataSources.TENSORS: ImageTensorDataSource(), "data_frame": ImageClassificationDataFrameDataSource(), DefaultDataSources.CSV: ImageClassificationDataFrameDataSource(), + DefaultDataSources.LABELSTUDIO: LabelStudioImageClassificationDataSource(), }, deserializer=deserializer or ImageDeserializer(), default_data_source=DefaultDataSources.FILES, diff --git a/flash/tabular/regression/data.py b/flash/tabular/regression/data.py index 04dd8cd3b4..52bd44cd77 100644 --- a/flash/tabular/regression/data.py +++ b/flash/tabular/regression/data.py @@ -1,18 +1,18 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from flash.tabular.data import TabularData - - -class TabularRegressionData(TabularData): - is_regression = True +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.tabular.data import TabularData + + +class TabularRegressionData(TabularData): + is_regression = True diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index c71538c0b9..085b30988c 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -25,6 +25,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataSources, LabelsState from flash.core.data.process import Deserializer, Postprocess, Preprocess +from flash.core.integrations.labelstudio.data_source import LabelStudioTextClassificationDataSource from flash.core.utilities.imports import _TEXT_AVAILABLE, requires if _TEXT_AVAILABLE: @@ -331,6 +332,9 @@ def __init__( DefaultDataSources.JSON: TextJSONDataSource(self.backbone, max_length=max_length), "data_frame": TextDataFrameDataSource(self.backbone, max_length=max_length), "sentences": TextSentencesDataSource(self.backbone, max_length=max_length), + DefaultDataSources.LABELSTUDIO: LabelStudioTextClassificationDataSource( + backbone=self.backbone, max_length=max_length + ), }, default_data_source="sentences", deserializer=TextDeserializer(backbone, max_length), diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index 0d6757d061..cf69c97883 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -28,6 +28,7 @@ PathsDataSource, ) from flash.core.data.process import Preprocess +from flash.core.integrations.labelstudio.data_source import LabelStudioVideoClassificationDataSource from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _KORNIA_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, lazy_import SampleCollection = None @@ -256,6 +257,13 @@ def __init__( decoder=decoder, **data_source_kwargs, ), + DefaultDataSources.LABELSTUDIO: LabelStudioVideoClassificationDataSource( + clip_sampler=clip_sampler, + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + **data_source_kwargs, + ), }, default_data_source=DefaultDataSources.FILES, ) diff --git a/flash_examples/integrations/labelstudio/image_classification.py b/flash_examples/integrations/labelstudio/image_classification.py new file mode 100644 index 0000000000..41e3aa7332 --- /dev/null +++ b/flash_examples/integrations/labelstudio/image_classification.py @@ -0,0 +1,45 @@ +import flash +from flash.core.classification import Labels +from flash.core.data.utils import download_data +from flash.core.finetuning import FreezeUnfreeze +from flash.core.integrations.labelstudio.visualizer import launch_app +from flash.image import ImageClassificationData, ImageClassifier + +# 1 Download data +download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/data.zip") + +# 2. Load export data +datamodule = ImageClassificationData.from_labelstudio( + export_json="data/project.json", + data_folder="data/upload/", + val_split=0.2, +) + +# 3. Fine tune a model +model = ImageClassifier( + backbone="resnet18", + num_classes=datamodule.num_classes, +) +trainer = flash.Trainer(max_epochs=3) + +trainer.finetune( + model, + datamodule=datamodule, + strategy=FreezeUnfreeze(unfreeze_epoch=1), +) +trainer.save_checkpoint("image_classification_model.pt") + +# 4. Predict from checkpoint +model = ImageClassifier.load_from_checkpoint("image_classification_model.pt") +model.serializer = Labels() + +predictions = model.predict( + [ + "data/test/1.jpg", + "data/test/2.jpg", + ] +) + +# 5. Visualize predictions +app = launch_app(datamodule) +print(app.show_predictions(predictions)) diff --git a/flash_examples/integrations/labelstudio/text_classification.py b/flash_examples/integrations/labelstudio/text_classification.py new file mode 100644 index 0000000000..930b75bc07 --- /dev/null +++ b/flash_examples/integrations/labelstudio/text_classification.py @@ -0,0 +1,38 @@ +import flash +from flash.core.data.utils import download_data +from flash.core.integrations.labelstudio.visualizer import launch_app +from flash.text import TextClassificationData, TextClassifier + +# 1. Create the DataModule +download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/text_data.zip", "./data/") + +backbone = "prajjwal1/bert-medium" + +datamodule = TextClassificationData.from_labelstudio( + export_json="data/project.json", + val_split=0.2, + backbone=backbone, +) + +# 2. Build the task +model = TextClassifier(backbone=backbone, num_classes=datamodule.num_classes) + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer(max_epochs=3) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") + +# 4. Classify a few sentences! How was the movie? +predictions = model.predict( + [ + "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", + "The worst movie in the history of cinema.", + "I come from Bulgaria where it 's almost impossible to have a tornado.", + ] +) + +# 5. Save the model! +trainer.save_checkpoint("text_classification_model.pt") + +# 6. Visualize predictions +app = launch_app(datamodule) +print(app.show_predictions(predictions)) diff --git a/flash_examples/integrations/labelstudio/video_classification.py b/flash_examples/integrations/labelstudio/video_classification.py new file mode 100644 index 0000000000..c9e76c88ab --- /dev/null +++ b/flash_examples/integrations/labelstudio/video_classification.py @@ -0,0 +1,40 @@ +import os + +import flash +from flash.core.data.utils import download_data +from flash.core.integrations.labelstudio.visualizer import launch_app +from flash.video import VideoClassificationData, VideoClassifier + +# 1 Download data +download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/video_data.zip") + +# 2. Load export data +datamodule = VideoClassificationData.from_labelstudio( + export_json="data/project.json", + data_folder="data/upload/", + val_split=0.2, + clip_sampler="uniform", + clip_duration=1, + decode_audio=False, +) + +# 3. Build the task +model = VideoClassifier( + backbone="slow_r50", + num_classes=datamodule.num_classes, +) + +# 4. Create the trainer and finetune the model +trainer = flash.Trainer(max_epochs=3) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") + +# 5. Make a prediction +predictions = model.predict(os.path.join(os.getcwd(), "data/test")) +print(predictions) + +# 6. Save the model! +trainer.save_checkpoint("video_classification.pt") + +# 7. Visualize predictions +app = launch_app(datamodule) +print(app.show_predictions(predictions)) diff --git a/tests/core/integrations/labelstudio/test_labelstudio.py b/tests/core/integrations/labelstudio/test_labelstudio.py new file mode 100644 index 0000000000..4f586dc00b --- /dev/null +++ b/tests/core/integrations/labelstudio/test_labelstudio.py @@ -0,0 +1,299 @@ +import pytest + +from flash.core.data.data_source import DefaultDataSources +from flash.core.data.utils import download_data +from flash.core.integrations.labelstudio.data_source import ( + LabelStudioDataSource, + LabelStudioImageClassificationDataSource, + LabelStudioTextClassificationDataSource, +) +from flash.core.integrations.labelstudio.visualizer import launch_app +from flash.image.classification.data import ImageClassificationData +from flash.text.classification.data import TextClassificationData +from flash.video.classification.data import VideoClassificationData, VideoClassificationPreprocess +from tests.helpers.utils import _IMAGE_TESTING, _TEXT_TESTING, _VIDEO_TESTING + + +def test_utility_load(): + """Test for label studio json loader.""" + data = [ + { + "id": 191, + "annotations": [ + { + "id": 130, + "completed_by": {"id": 1, "email": "test@heartex.com", "first_name": "", "last_name": ""}, + "result": [ + { + "id": "dv1Tn-zdez", + "type": "rectanglelabels", + "value": { + "x": 46.5625, + "y": 21.666666666666668, + "width": 8.75, + "height": 12.083333333333334, + "rotation": 0, + "rectanglelabels": ["Car"], + }, + "to_name": "image", + "from_name": "label", + "image_rotation": 0, + "original_width": 320, + "original_height": 240, + }, + { + "id": "KRa8jEvpK0", + "type": "rectanglelabels", + "value": { + "x": 66.875, + "y": 22.5, + "width": 14.0625, + "height": 17.5, + "rotation": 0, + "rectanglelabels": ["Car"], + }, + "to_name": "image", + "from_name": "label", + "image_rotation": 0, + "original_width": 320, + "original_height": 240, + }, + { + "id": "kAKaSxNnvH", + "type": "rectanglelabels", + "value": { + "x": 93.4375, + "y": 22.916666666666668, + "width": 6.5625, + "height": 18.75, + "rotation": 0, + "rectanglelabels": ["Car"], + }, + "to_name": "image", + "from_name": "label", + "image_rotation": 0, + "original_width": 320, + "original_height": 240, + }, + { + "id": "_VXKV2nz14", + "type": "rectanglelabels", + "value": { + "x": 0, + "y": 39.583333333333336, + "width": 100, + "height": 60.416666666666664, + "rotation": 0, + "rectanglelabels": ["Road"], + }, + "to_name": "image", + "from_name": "label", + "image_rotation": 0, + "original_width": 320, + "original_height": 240, + }, + { + "id": "vCuvi_jLHn", + "type": "rectanglelabels", + "value": { + "x": 0, + "y": 17.5, + "width": 48.125, + "height": 41.66666666666666, + "rotation": 0, + "rectanglelabels": ["Obstacle"], + }, + "to_name": "image", + "from_name": "label", + "image_rotation": 0, + "original_width": 320, + "original_height": 240, + }, + ], + "was_cancelled": False, + "ground_truth": False, + "prediction": {}, + "result_count": 0, + "task": 191, + } + ], + "file_upload": "Highway20030201_1002591.jpg", + "data": {"image": "/data/upload/Highway20030201_1002591.jpg"}, + "meta": {}, + "created_at": "2021-05-12T18:43:41.241095Z", + "updated_at": "2021-05-12T19:42:28.156609Z", + "project": 7, + } + ] + ds = LabelStudioDataSource._load_json_data(data=data, data_folder=".", multi_label=False) + assert ds[3] == {"image"} + assert ds[2] == {"Road", "Car", "Obstacle"} + assert len(ds[1]) == 0 + assert len(ds[0]) == 5 + ds_multi = LabelStudioDataSource._load_json_data(data=data, data_folder=".", multi_label=True) + assert ds_multi[3] == {"image"} + assert ds_multi[2] == {"Road", "Car", "Obstacle"} + assert len(ds_multi[1]) == 0 + assert len(ds_multi[0]) == 5 + + +def test_datasource_labelstudio(): + """Test creation of LabelStudioDataSource.""" + download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/data.zip") + ds = LabelStudioDataSource() + data = { + "data_folder": "data/upload/", + "export_json": "data/project.json", + "split": 0.2, + "multi_label": False, + } + train, val, test, predict = ds.to_datasets(train_data=data) + train_sample = train[0] + val_sample = val[0] + assert train_sample + assert val_sample + assert test + assert not predict + ds_no_split = LabelStudioDataSource() + data = { + "data_folder": "data/upload/", + "export_json": "data/project.json", + "multi_label": True, + } + train, val, test, predict = ds_no_split.to_datasets(train_data=data) + sample = train[0] + assert sample + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_datasource_labelstudio_image(): + """Test creation of LabelStudioImageClassificationDataSource from images.""" + download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/data_nofile.zip") + + data = { + "data_folder": "data/upload/", + "export_json": "data/project_nofile.json", + "split": 0.2, + "multi_label": True, + } + ds = LabelStudioImageClassificationDataSource() + train, val, test, predict = ds.to_datasets(train_data=data, val_data=data, test_data=data, predict_data=data) + train_sample = train[0] + val_sample = val[0] + test_sample = test[0] + predict_sample = predict[0] + assert train_sample + assert val_sample + assert test_sample + assert predict_sample + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_datamodule_labelstudio_image(): + """Test creation of LabelStudioImageClassificationDataSource and Datamodule from images.""" + download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/data.zip") + + datamodule = ImageClassificationData.from_labelstudio( + train_export_json="data/project.json", + train_data_folder="data/upload/", + test_export_json="data/project.json", + test_data_folder="data/upload/", + val_split=0.5, + ) + assert datamodule + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_label_studio_predictions_visualization(): + """Test creation of LabelStudioImageClassificationDataSource and Datamodule from images.""" + download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/data.zip") + + datamodule = ImageClassificationData.from_labelstudio( + train_export_json="data/project.json", + train_data_folder="data/upload/", + test_export_json="data/project.json", + test_data_folder="data/upload/", + val_split=0.5, + ) + assert datamodule + app = launch_app(datamodule) + predictions = [0, 1, 1, 0] + vis_predictions = app.show_predictions(predictions) + assert len(vis_predictions) == 4 + assert vis_predictions[0]["result"][0]["id"] != vis_predictions[3]["result"][0]["id"] + assert vis_predictions[1]["result"][0]["id"] != vis_predictions[2]["result"][0]["id"] + tasks_predictions = app.show_tasks(predictions) + assert len(tasks_predictions) == 4 + tasks_predictions_json = app.show_tasks(predictions, export_json="data/project.json") + assert tasks_predictions_json + + +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_datasource_labelstudio_text(): + """Test creation of LabelStudioTextClassificationDataSource and Datamodule from text.""" + download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/text_data.zip", "./data/") + backbone = "prajjwal1/bert-medium" + data = { + "data_folder": "data/upload/", + "export_json": "data/project.json", + "split": 0.2, + "multi_label": False, + } + ds = LabelStudioTextClassificationDataSource(backbone=backbone) + train, val, test, predict = ds.to_datasets(train_data=data, test_data=data) + train_sample = train[0] + test_sample = test[0] + val_sample = val[0] + assert train_sample + assert test_sample + assert val_sample + assert not predict + + +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_datamodule_labelstudio_text(): + """Test creation of LabelStudioTextClassificationDataSource and Datamodule from text.""" + download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/text_data.zip", "./data/") + backbone = "prajjwal1/bert-medium" + datamodule = TextClassificationData.from_labelstudio( + train_export_json="data/project.json", + val_export_json="data/project.json", + test_export_json="data/project.json", + predict_export_json="data/project.json", + data_folder="data/upload/", + val_split=0.8, + backbone=backbone, + ) + assert datamodule + + +@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +def test_datasource_labelstudio_video(): + """Test creation of LabelStudioVideoClassificationDataSource from video.""" + download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/video_data.zip") + data = {"data_folder": "data/upload/", "export_json": "data/project.json", "multi_label": True} + preprocess = VideoClassificationPreprocess() + ds = preprocess.data_source_of_name(DefaultDataSources.LABELSTUDIO) + train, val, test, predict = ds.to_datasets(train_data=data, test_data=data) + sample_iter = iter(train) + sample = next(sample_iter) + assert train + assert not val + assert test + assert not predict + assert sample + + +@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +def test_datamodule_labelstudio_video(): + """Test creation of Datamodule from video.""" + download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/video_data.zip") + datamodule = VideoClassificationData.from_labelstudio( + export_json="data/project.json", + data_folder="data/upload/", + val_split=0.2, + clip_sampler="uniform", + clip_duration=1, + decode_audio=False, + ) + assert datamodule