diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml index 4d4684e4db..6dbbcabc0e 100644 --- a/.azure-pipelines/gpu-tests.yml +++ b/.azure-pipelines/gpu-tests.yml @@ -25,7 +25,7 @@ jobs: # ToDo: this need to have installed docker in the base image... #container: "pytorchlightning/pytorch_lightning:base-cuda-py$[ variables['python.version'] ]-torch1.6" container: - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.8-torch1.7" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.8-torch1.8" #endpoint: azureContainerRegistryConnection options: "--ipc=host --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all" diff --git a/.deepsource.toml b/.deepsource.toml index 3300d8f939..ea8a9439b1 100644 --- a/.deepsource.toml +++ b/.deepsource.toml @@ -17,7 +17,3 @@ enabled = true [analyzers.meta] runtime_version = "3.x.x" max_line_length = 120 - -[[transformers]] -name = "autopep8" -enabled = true diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 354b5151b2..6d0283c18c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -5,7 +5,7 @@ # the repo. Unless a later match takes precedence, # @global-owner1 and @global-owner2 will be requested for # review when someone opens a pull request. -* @ethanwharris @borda @tchaton @justusschock @carmocca @kaushikb11 +* @ethanwharris @borda @tchaton @ananyahjha93 @justusschock @carmocca @kaushikb11 # owners /.github/CODEOWNERS @williamfalcon @@ -17,12 +17,12 @@ /__init__.py @borda @ethanwharris # CI/CD -/.github/workflows/ @borda @ethanwharris +/.github/workflows/ @borda @ethanwharris @ananyahjha93 # configs in root -/*.yml @borda @ethanwharris +/*.yml @borda @ethanwharris @ananyahjha93 # Docs -/docs/ @edenlightning @ethanwharris -/.github/*.md @edenlightning @ethanwharris -/.github/ISSUE_TEMPLATE/*.md @edenlightning @ethanwharris -/docs/source/conf.py @borda @ethanwharris +/docs/ @edenlightning @ethanwharris @ananyahjha93 +/.github/*.md @edenlightning @ethanwharris @ananyahjha93 +/.github/ISSUE_TEMPLATE/*.md @edenlightning @ethanwharris @ananyahjha93 +/docs/source/conf.py @borda @ethanwharris @ananyahjha93 diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 5db03b4fd7..254234c8fd 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -19,32 +19,52 @@ jobs: os: [ubuntu-20.04, macOS-10.15, windows-2019] python-version: [3.6, 3.8] requires: ['minimal', 'latest'] - topic: ['devel'] + topic: [['devel']] include: - os: ubuntu-20.04 python-version: 3.8 requires: 'latest' - topic: 'image' + topic: ['image'] - os: ubuntu-20.04 python-version: 3.8 requires: 'minimal' - topic: 'image' + topic: ['image'] - os: ubuntu-20.04 python-version: 3.8 requires: 'latest' - topic: 'video' + topic: ['image','image_extras'] - os: ubuntu-20.04 python-version: 3.8 requires: 'latest' - topic: 'tabular' + topic: ['video'] - os: ubuntu-20.04 python-version: 3.8 requires: 'latest' - topic: 'text' + topic: ['video','video_extras'] - os: ubuntu-20.04 python-version: 3.8 requires: 'latest' - topic: 'serve' + topic: ['tabular'] + - os: ubuntu-20.04 + python-version: 3.8 + requires: 'latest' + topic: ['text'] + - os: ubuntu-20.04 + python-version: 3.8 + requires: 'latest' + topic: ['pointcloud'] + - os: ubuntu-20.04 + python-version: 3.8 + requires: 'latest' + topic: ['serve'] + - os: ubuntu-20.04 + python-version: 3.8 + requires: 'latest' + topic: ['graph'] + - os: ubuntu-20.04 + python-version: 3.8 + requires: 'latest' + topic: ['audio'] # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 35 @@ -64,7 +84,7 @@ jobs: brew install libomp # https://github.com/pytorch/pytorch/issues/20030 - name: Install graphviz - if: matrix.topic == 'serve' + if: matrix.topic[0] == 'serve' run: | sudo apt-get install graphviz @@ -93,24 +113,32 @@ jobs: uses: actions/cache@v2 with: path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.topic }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements.txt') }} + key: ${{ runner.os }}-${{ matrix.python-version }}-${{ join(matrix.topic,'-') }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements.txt') }} restore-keys: | - ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.topic }}-${{ matrix.requires }}-pip- + ${{ runner.os }}-${{ matrix.python-version }}-${{ join(matrix.topic,'-') }}-${{ matrix.requires }}-pip- - name: Install dependencies run: | python --version pip --version - pip install '.[${{ matrix.topic }}]' --pre --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html + pip install torch>=1.8 + pip install '.[${{ join(matrix.topic,',') }}]' --pre --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html pip install '.[test]' --pre --upgrade pip list shell: bash - name: Install serve test dependencies - if: matrix.topic == 'serve' + if: matrix.topic[0] == 'serve' run: | pip install '.[all]' --pre --upgrade + - name: Install audio test dependencies + if: matrix.topic[0] == 'audio' + run: | + sudo apt-get install libsndfile1 + pip install matplotlib + pip install '.[audio,image]' --pre --upgrade + - name: Cache datasets uses: actions/cache@v2 with: @@ -120,7 +148,7 @@ jobs: - name: Tests env: - FLASH_TEST_TOPIC: ${{ matrix.topic }} + FLASH_TEST_TOPIC: ${{ join(matrix.topic,',') }} FIFTYONE_DO_NOT_TRACK: true run: | # tox --sitepackages diff --git a/.github/workflows/code-format.yml b/.github/workflows/code-format.yml index 407ad86b3a..1831cf898a 100644 --- a/.github/workflows/code-format.yml +++ b/.github/workflows/code-format.yml @@ -23,36 +23,6 @@ jobs: - name: PEP8 run: flake8 . - #format-check-yapf: - # runs-on: ubuntu-20.04 - # steps: - # - uses: actions/checkout@master - # - uses: actions/setup-python@v2 - # with: - # python-version: 3.8 - # - name: Install dependencies - # run: | - # pip install --upgrade pip - # pip install yapf - # pip list - # shell: bash - # - name: yapf - # run: yapf --diff --parallel --recursive . - - #imports-check-isort: - # runs-on: ubuntu-20.04 - # steps: - # - uses: actions/checkout@master - # - uses: actions/setup-python@v2 - # with: - # python-version: 3.8 - # - name: Install isort - # run: | - # pip install isort - # pip list - # - name: isort - # run: isort --check-only . - #typing-check-mypy: # runs-on: ubuntu-20.04 # steps: diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml index 769450e9ab..d2ae660242 100644 --- a/.github/workflows/docs-check.yml +++ b/.github/workflows/docs-check.yml @@ -7,23 +7,6 @@ on: # Trigger the workflow on push or pull request, but only for the master bran branches: [master] jobs: - sphinx-docs: - runs-on: ubuntu-20.04 - steps: - - uses: actions/checkout@v2 - - uses: ammaraskar/sphinx-action@master - with: - # git is required to clone the docs theme - # before custom requirement are resolved https://github.com/ammaraskar/sphinx-action/issues/16 - pre-build-command: "apt-get update -y && apt-get install -y gcc git pandoc && pip install . && pip install -r ./requirements/docs.txt" - docs-folder: "docs/" - repo-token: "${{ secrets.GITHUB_TOKEN }}" - - uses: actions/upload-artifact@v2 - with: - name: docs-results-${{ github.sha }} - path: docs/build/html/ - - make-docs: runs-on: ubuntu-20.04 diff --git a/.gitignore b/.gitignore index 063c3d52c7..f757f1f042 100644 --- a/.gitignore +++ b/.gitignore @@ -75,6 +75,11 @@ instance/ # Sphinx documentation docs/_build/ +docs/api/ +docs/notebooks/ +docs/source/api/generated/ +docs/source/integrations/generated/ +docs/source/generated/ # PyBuilder target/ @@ -133,8 +138,6 @@ dmypy.json # Pyre type checker .pyre/ -docs/notebooks/ -docs/api/ titanic.csv .vscode .venv @@ -143,6 +146,7 @@ data_folder *.zip flash_notebooks/*.py flash_notebooks/data +/data MNIST* titanic hymenoptera_data @@ -157,3 +161,9 @@ CameraRGB CameraSeg jigsaw_toxic_comments flash_examples/serve/tabular_classification/data +logs/cache/* +flash_examples/data +flash_examples/checkpoints +timit/ +urban8k_images/ +__MACOSX diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8a1aafd590..fec61fe332 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,28 +34,46 @@ repos: - id: check-added-large-files - id: detect-private-key + - repo: https://github.com/asottile/pyupgrade + rev: v2.23.0 + hooks: + - id: pyupgrade + args: [--py36-plus] + name: Upgrade code + - repo: https://github.com/PyCQA/isort - rev: 5.9.1 + rev: 5.9.3 hooks: - id: isort name: imports require_serial: false - - repo: https://github.com/pre-commit/mirrors-yapf - rev: v0.31.0 + - repo: https://github.com/kynan/nbstripout + rev: 0.5.0 hooks: - - id: yapf - name: formatting - language: python - require_serial: false + - id: nbstripout + + - repo: https://github.com/myint/docformatter + rev: v1.4 + hooks: + - id: docformatter + args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120] + + - repo: https://github.com/psf/black + rev: 21.7b0 + hooks: + - id: black + name: Format code + + - repo: https://github.com/asottile/blacken-docs + rev: v1.10.0 + hooks: + - id: blacken-docs + args: [ --line-length=120 ] + additional_dependencies: [ black==21.7b0 ] - repo: https://github.com/PyCQA/flake8 rev: 3.9.2 hooks: - id: flake8 name: PEP8 - - - repo: https://github.com/kynan/nbstripout - rev: 0.5.0 - hooks: - - id: nbstripout diff --git a/CHANGELOG.md b/CHANGELOG.md index 877962446e..431b1e771a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,17 +10,73 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for (input, target) style datasets (e.g. torchvision) to the from_datasets method ([#552](https://github.com/PyTorchLightning/lightning-flash/pull/552)) +- Added support for `from_csv` and `from_data_frame` to `ImageClassificationData` ([#556](https://github.com/PyTorchLightning/lightning-flash/pull/556)) + +- Added SimCLR, SwAV, Barlow-twins pretrained weights for resnet50 backbone in ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) + +- Added support for Semantic Segmentation backbones and heads from `segmentation-models.pytorch` ([#562](https://github.com/PyTorchLightning/lightning-flash/pull/562)) + +- Added support for nesting of `Task` objects ([#575](https://github.com/PyTorchLightning/lightning-flash/pull/575)) + +- Added `PointCloudSegmentation` Task ([#566](https://github.com/PyTorchLightning/lightning-flash/pull/566)) + +- Added `PointCloudObjectDetection` Task ([#600](https://github.com/PyTorchLightning/lightning-flash/pull/600)) + +- Added a `GraphClassifier` task ([#73](https://github.com/PyTorchLightning/lightning-flash/pull/73)) + +- Added the option to pass `pretrained` as a string to `SemanticSegmentation` to change pretrained weights to load from `segmentation-models.pytorch` ([#587](https://github.com/PyTorchLightning/lightning-flash/pull/587)) + +- Added support for `field` parameter for loadng JSON based datasets in text tasks. ([#585](https://github.com/PyTorchLightning/lightning-flash/pull/585)) + +- Added `AudioClassificationData` and an example for classifying audio spectrograms ([#594](https://github.com/PyTorchLightning/lightning-flash/pull/594)) + +- Added a `SpeechRecognition` task for speech to text using Wav2Vec ([#586](https://github.com/PyTorchLightning/lightning-flash/pull/586)) + +- Added Flash Zero, a zero code command line ML platform built with flash ([#611](https://github.com/PyTorchLightning/lightning-flash/pull/611)) + +- Added support for `.npy` and `.npz` files to `ImageClassificationData` and `AudioClassificationData` ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651)) + +- Added support for `from_csv` to the `AudioClassificationData` ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651)) + +- Added option to pass a `resolver` to the `from_csv` and `from_pandas` methods of `ImageClassificationData`, which is used to resolve filenames given IDs ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651)) + +- Added integration with IceVision for the `ObjectDetector` ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608)) + +- Added keypoint detection task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608)) + +- Added instance segmentation task ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608)) + +- Added Torch ORT support to Transformer based tasks ([#667](https://github.com/PyTorchLightning/lightning-flash/pull/667)) + +- Added support for flash zero with the `InstanceSegmentation` and `KeypointDetector` tasks ([#672](https://github.com/PyTorchLightning/lightning-flash/pull/672)) + +- Added support for `in_chans` argument to the flash ResNet to control the expected number of input channels ([#673](https://github.com/PyTorchLightning/lightning-flash/pull/673)) + ### Changed +- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) -### Deprecated +- Removed bolts pretrained weights for SSL from ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) +- Changed the behaviour of the `sampler` argument of the `DataModule` to take a `Sampler` type rather than instantiated object ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651)) -### Fixed +- Changed arguments to `ObjectDetector`, use `head` instead of `model` and append `_fpn` to the backbone name instead of the `fpn` argument ([#608](https://github.com/PyTorchLightning/lightning-flash/pull/608)) +### Fixed - Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version ([#493](https://github.com/PyTorchLightning/lightning-flash/pull/493)) +- Fixed a bug where train and validation metrics weren't being correctly computed ([#559](https://github.com/PyTorchLightning/lightning-flash/pull/559)) + +- Fixed a bug where an uncaught ValueError could be raised when checking if a module is available ([#615](https://github.com/PyTorchLightning/lightning-flash/pull/615)) + +- Fixed a bug where some tasks were not compatible with PyTorch 1.7 due to use of `torch.jit.isinstance` ([#611](https://github.com/PyTorchLightning/lightning-flash/pull/611)) + +- Fixed a bug where custom samplers would not be properly forwarded to the data loader ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651)) + +- Fixed a bug where it was not possible to pass no metrics to the `ImageClassifier` or `TestClassifier` ([#660](https://github.com/PyTorchLightning/lightning-flash/pull/660)) + +- Fixed a bug where `drop_last` would be set to True during prediction and testing ([#671](https://github.com/PyTorchLightning/lightning-flash/pull/671)) ## [0.4.0] - 2021-06-22 diff --git a/Makefile b/Makefile index 6fcee001e6..d851e1b53c 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,6 @@ clean: rm -rf $(shell find . -name "mlruns") rm -rf .mypy_cache rm -rf .pytest_cache + rm -rf **/__pycache__ rm -rf ./docs/build rm -rf ./docs/source/**/generated - rm -rf ./docs/source/api diff --git a/README.md b/README.md index a950b6c458..03596edcdb 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@

Installation • - Docs • + DocsAboutPredictionFinetuning • @@ -22,14 +22,13 @@

-[![Stable API](https://img.shields.io/static/v1.svg?label=API&message=stable&color=green)](https://img.shields.io/static/v1.svg?label=API&message=stable&color=green) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/lightning-flash)](https://pypi.org/project/lightning-flash/) [![PyPI Status](https://badge.fury.io/py/lightning-flash.svg)](https://badge.fury.io/py/lightning-flash) [![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ) [![Discourse status](https://img.shields.io/discourse/status?server=https%3A%2F%2Fforums.pytorchlightning.ai)](https://forums.pytorchlightning.ai/) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/PytorchLightning/pytorch-lightning/blob/master/LICENSE) -[![Documentation Status](https://readthedocs.org/projects/lightning-flash/badge/?version=latest)](https://lightning-flash.readthedocs.io/en/latest/?badge=latest) +[![Documentation Status](https://readthedocs.org/projects/lightning-flash/badge/?version=latest)](https://lightning-flash.readthedocs.io/en/stable/?badge=stable) ![CI testing](https://github.com/PyTorchLightning/lightning-flash/workflows/CI%20testing/badge.svg?branch=master&event=push) [![codecov](https://codecov.io/gh/PyTorchLightning/lightning-flash/branch/master/graph/badge.svg?token=oLuUr9q1vt)](https://codecov.io/gh/PyTorchLightning/lightning-flash) @@ -41,8 +40,19 @@ --- + +__Note:__ Flash is currently being tested on real-world use cases and is in active development. Please [open an issue](https://github.com/PyTorchLightning/lightning-flash/issues/new/choose) if you find anything that isn't working as expected. + +--- + ## News -[Read our launch blogpost](https://pytorch-lightning.medium.com/introducing-lightning-flash-the-fastest-way-to-get-started-with-deep-learning-202f196b3b98) + +- Jul 12: Flash Task-a-thon community sprint with 25+ community members +- Jul 1: [Lightning Flash 0.4](https://devblog.pytorchlightning.ai/lightning-flash-0-4-flash-serve-fiftyone-multi-label-text-classification-and-jit-support-97428276c06f) +- Jun 22: [Ushering in the New Age of Video Understanding with PyTorch](https://medium.com/pytorch/ushering-in-the-new-age-of-video-understanding-with-pytorch-1d85078e8015) +- May 24: [Lightning Flash 0.3](https://devblog.pytorchlightning.ai/lightning-flash-0-3-new-tasks-visualization-tools-data-pipeline-and-flash-registry-api-1e236ba9530) +- May 20: [Video Understanding with PyTorch](https://towardsdatascience.com/video-understanding-made-simple-with-pytorch-video-and-lightning-flash-c7d65583c37e) +- Feb 2: [Read our launch blogpost](https://pytorch-lightning.medium.com/introducing-lightning-flash-the-fastest-way-to-get-started-with-deep-learning-202f196b3b98) --- @@ -110,10 +120,12 @@ from flash.text import TranslationTask model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt") # 2. Translate a few sentences! -predictions = model.predict([ - "BBC News went to meet one of the project's first graduates.", - "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", -]) +predictions = model.predict( + [ + "BBC News went to meet one of the project's first graduates.", + "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", + ] +) print(predictions) ``` @@ -128,7 +140,7 @@ model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws. model.serve() ``` -Credits to @rlizzo, @hhsecond, @lantiga, @luiscape for building Flash Serve Engine. +Credits to [@rlizzo](https://github.com/rlizzo), [@hhsecond](https://github.com/hhsecond), [@lantiga](https://github.com/lantiga), [@luiscape](https://github.com/luiscape) for building Flash Serve Engine. ### Finetuning @@ -140,7 +152,7 @@ from flash.core.data.utils import download_data from flash.image import ImageClassificationData, ImageClassifier # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") # 2. Load the data datamodule = ImageClassificationData.from_folders( @@ -168,10 +180,10 @@ Then use the finetuned model: from flash.image import ImageClassifier # load the finetuned model -classifier = ImageClassifier.load_from_checkpoint('image_classification_model.pt') +classifier = ImageClassifier.load_from_checkpoint("image_classification_model.pt") # predict! -predictions = classifier.predict('data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg') +predictions = classifier.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") print(predictions) ``` @@ -191,16 +203,16 @@ from flash.core.data.utils import download_data from flash.image import ImageEmbedder # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") # 2. Create an ImageEmbedder with resnet50 trained on imagenet. -embedder = ImageEmbedder(backbone="resnet50", embedding_dim=128) +embedder = ImageEmbedder(backbone="resnet50") # 3. Generate an embedding from an image path. -embeddings = embedder.predict('data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg') +embeddings = embedder.predict("data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg") # 4. Print embeddings shape -print(embeddings.shape) +print(embeddings[0].shape) ``` @@ -213,11 +225,12 @@ Flash has a [Summarization task](https://lightning-flash.readthedocs.io/en/lates ```python import flash +import torch from flash.core.data.utils import download_data from flash.text import SummarizationData, SummarizationTask # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/") # 2. Load the data datamodule = SummarizationData.from_csv( @@ -232,7 +245,7 @@ datamodule = SummarizationData.from_csv( model = SummarizationTask() # 4. Create the trainer. Run once on data -trainer = flash.Trainer(max_epochs=1, gpus=1, precision=16) +trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count(), precision=16) # 5. Fine-tune the model trainer.finetune(model, datamodule=datamodule) @@ -260,13 +273,13 @@ To illustrate, say we want to build a model to predict if a passenger survived o from torchmetrics.classification import Accuracy, Precision, Recall import flash from flash.core.data.utils import download_data -from flash.tabular import TabularClassifier, TabularData +from flash.tabular import TabularClassifier, TabularClassificationData # 1. Download the data -download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/') +download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/") # 2. Load the data -datamodule = TabularData.from_csv( +datamodule = TabularClassificationData.from_csv( ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], "Fare", target_fields="Survived", @@ -318,9 +331,9 @@ download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0. # 2. Load the Data datamodule = ObjectDetectionData.from_coco( - train_folder="data/coco128/images/train2017/", - train_ann_file="data/coco128/annotations/instances_train2017.json", - batch_size=2 + train_folder="data/coco128/images/train2017/", + train_ann_file="data/coco128/annotations/instances_train2017.json", + batch_size=2, ) # 3. Build the model @@ -375,9 +388,7 @@ datamodule = VideoClassificationData.from_folders( ) # 3. Build the model -model = VideoClassifier( - backbone="x3d_xs", num_classes=datamodule.num_classes, pretrained=False -) +model = VideoClassifier(backbone="x3d_xs", num_classes=datamodule.num_classes, pretrained=False) # 4. Create the trainer trainer = flash.Trainer(max_epochs=3) @@ -410,7 +421,9 @@ from flash.core.data.utils import download_data from flash.image import SemanticSegmentation, SemanticSegmentationData # 1. Download the Data -download_data("https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", "data/") +download_data( + "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", "data/" +) # 2. Load the Data datamodule = SemanticSegmentationData.from_folders( @@ -444,9 +457,9 @@ python flash_examples/finetuning/semantic_segmentation.py -### Example 7: Style Transfer with Pystiche +### Example 7: Style Transfer with pystiche -Flash has a [Style Transfer task](https://lightning-flash.readthedocs.io/en/latest/reference/style_transfer.html) for Neural Style Transfer (NST) with [Pystiche](https://github.com/pystiche/pystiche). +Flash has a [Style Transfer task](https://lightning-flash.readthedocs.io/en/latest/reference/style_transfer.html) for Neural Style Transfer (NST) with [pystiche](https://pystiche.org).
View example @@ -497,15 +510,10 @@ from torch.utils.data import DataLoader, random_split from torchvision import transforms, datasets # model -model = nn.Sequential( - nn.Flatten(), - nn.Linear(28 * 28, 128), - nn.ReLU(), - nn.Linear(128, 10) -) +model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10)) # data -dataset = datasets.MNIST('./data_folder', download=True, transform=transforms.ToTensor()) +dataset = datasets.MNIST("./data_folder", download=True, transform=transforms.ToTensor()) train, val = random_split(dataset, [55000, 5000]) # task @@ -527,6 +535,7 @@ from torchmetrics import Accuracy from typing import Callable, Mapping, Sequence, Type, Union from flash.core.classification import ClassificationTask + class LinearClassifier(ClassificationTask): def __init__( self, @@ -551,9 +560,9 @@ class LinearClassifier(ClassificationTask): def forward(self, x): return self.linear(x) + classifier = LinearClassifier(128, 10) ... - ``` When you reach the limits of the flexibility provided by Flash, then seamlessly transition to PyTorch Lightning which @@ -577,9 +586,7 @@ download_data( ) # 2. Load the model from a checkpoint and use the FiftyOne serializer -model = ObjectDetector.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/object_detection_model.pt" -) +model = ObjectDetector.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/object_detection_model.pt") model.serializer = FiftyOneDetectionLabels() # 3. Detect the object on the images @@ -600,12 +607,14 @@ The lightning + Flash team is hard at work building more tasks for common deep-l Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ) and/or read our [CONTRIBUTING](https://github.com/PyTorchLightning/lightning-flash/blob/master/.github/CONTRIBUTING.md) guidelines to get help becoming a contributor! ## Community +Flash is maintained by our [core contributors](https://lightning-flash.readthedocs.io/en/latest/governance.html). + For help or questions, join our huge community on [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)! ## Citations -We’re excited to continue the strong legacy of opensource software and have been inspired over the years by Caffee, Theano, Keras, PyTorch, torchbearer, and fast.ai. When/if a paper is written about this, we’ll be happy to cite these frameworks and the corresponding authors. +We’re excited to continue the strong legacy of opensource software and have been inspired over the years by Caffe, Theano, Keras, PyTorch, torchbearer, and fast.ai. When/if a paper is written about this, we’ll be happy to cite these frameworks and the corresponding authors. -Flash leverages models from [torchvision](https://pytorch.org/vision/stable/index.html), [huggingface/transformers](https://huggingface.co/transformers/), [timm](https://github.com/rwightman/pytorch-image-models), and [pytorch-tabnet](https://dreamquark-ai.github.io/tabnet/) for the `vision`, `text`, and `tabular` tasks respectively. Also supports self-supervised backbones from [bolts](https://github.com/PyTorchLightning/lightning-bolts). +Flash leverages models from [torchvision](https://pytorch.org/vision/stable/index.html), [huggingface/transformers](https://huggingface.co/transformers/), [timm](https://github.com/rwightman/pytorch-image-models), [open3d-ml](https://github.com/intel-isl/Open3D-ML) for pointcloud, [pytorch-tabnet](https://dreamquark-ai.github.io/tabnet/), and [asteroid](https://github.com/asteroid-team/asteroid) for the `vision`, `text`, `tabular`, and `audio` tasks respectively. Also supports self-supervised backbones from [bolts](https://github.com/PyTorchLightning/lightning-bolts). ## License Please observe the Apache 2.0 license that is listed in this repository. In addition diff --git a/docs/source/_static/images/data_serving_flow.png b/docs/source/_static/images/data_serving_flow.png deleted file mode 100644 index 511309e954..0000000000 Binary files a/docs/source/_static/images/data_serving_flow.png and /dev/null differ diff --git a/docs/source/_static/images/inference_server.png b/docs/source/_static/images/inference_server.png deleted file mode 100644 index 219a95cd36..0000000000 Binary files a/docs/source/_static/images/inference_server.png and /dev/null differ diff --git a/docs/source/_static/images/swagger_ui.png b/docs/source/_static/images/swagger_ui.png deleted file mode 100644 index 99a983f23b..0000000000 Binary files a/docs/source/_static/images/swagger_ui.png and /dev/null differ diff --git a/docs/source/_static/main.css b/docs/source/_static/main.css new file mode 100644 index 0000000000..f636f8227c --- /dev/null +++ b/docs/source/_static/main.css @@ -0,0 +1,3 @@ +.longtable col { + width: 50% !important; +} diff --git a/docs/source/_templates/classtemplate.rst b/docs/source/_templates/classtemplate.rst new file mode 100644 index 0000000000..398a0ec07c --- /dev/null +++ b/docs/source/_templates/classtemplate.rst @@ -0,0 +1,14 @@ +.. role:: hidden + :class: hidden-section +.. currentmodule:: {{ module }} + + +{{ name | underline }} + +.. autoclass:: {{ name }} + :members: + + +.. + autogenerated from source/_templates/classtemplate.rst + note it does not have :inherited-members: diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index 7f5e2d32db..b1cc0680bb 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -4,7 +4,7 @@ {% block footer %} {{ super() }} {% endblock %} diff --git a/docs/source/api/audio.rst b/docs/source/api/audio.rst new file mode 100644 index 0000000000..ae6455c6d8 --- /dev/null +++ b/docs/source/api/audio.rst @@ -0,0 +1,43 @@ +########### +flash.audio +########### + +.. contents:: + :depth: 1 + :local: + :backlinks: top + +.. currentmodule:: flash.audio + +Classification +______________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~classification.data.AudioClassificationData + ~classification.data.AudioClassificationPreprocess + +Speech Recognition +__________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~speech_recognition.data.SpeechRecognitionData + ~speech_recognition.model.SpeechRecognition + + speech_recognition.data.SpeechRecognitionPreprocess + speech_recognition.data.SpeechRecognitionBackboneState + speech_recognition.data.SpeechRecognitionPostprocess + speech_recognition.data.SpeechRecognitionCSVDataSource + speech_recognition.data.SpeechRecognitionJSONDataSource + speech_recognition.data.BaseSpeechRecognition + speech_recognition.data.SpeechRecognitionFileDataSource + speech_recognition.data.SpeechRecognitionPathsDataSource + speech_recognition.data.SpeechRecognitionDatasetDataSource + speech_recognition.data.SpeechRecognitionDeserializer diff --git a/docs/source/api/core.rst b/docs/source/api/core.rst new file mode 100644 index 0000000000..9455691b39 --- /dev/null +++ b/docs/source/api/core.rst @@ -0,0 +1,117 @@ +########## +flash.core +########## + +.. contents:: + :depth: 1 + :local: + :backlinks: top + +flash.core.adapter +__________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.adapter.Adapter + ~flash.core.adapter.AdapterTask + +flash.core.classification +_________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.classification.Classes + ~flash.core.classification.ClassificationSerializer + ~flash.core.classification.ClassificationTask + ~flash.core.classification.FiftyOneLabels + ~flash.core.classification.Labels + ~flash.core.classification.Logits + ~flash.core.classification.PredsClassificationSerializer + ~flash.core.classification.Probabilities + +flash.core.finetuning +_____________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.finetuning.FlashBaseFinetuning + ~flash.core.finetuning.FreezeUnfreeze + ~flash.core.finetuning.NoFreeze + ~flash.core.finetuning.UnfreezeMilestones + +flash.core.integrations.fiftyone +________________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + ~flash.core.integrations.fiftyone.utils.visualize + +flash.core.integrations.icevision +_________________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + ~flash.core.integrations.icevision.transforms.IceVisionTransformAdapter + ~flash.core.integrations.icevision.transforms.default_transforms + ~flash.core.integrations.icevision.transforms.train_default_transforms + +flash.core.model +________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.model.BenchmarkConvergenceCI + ~flash.core.model.CheckDependenciesMeta + ~flash.core.model.ModuleWrapperBase + ~flash.core.model.DatasetProcessor + ~flash.core.model.Task + +flash.core.registry +___________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.registry.FlashRegistry + +flash.core.optimizers +_____________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.optimizers.LARS + ~flash.core.optimizers.LAMB + ~flash.core.optimizers.LinearWarmupCosineAnnealingLR + +Utilities +_________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + ~flash.core.trainer.from_argparse_args + ~flash.core.utilities.apply_func.get_callable_name + ~flash.core.utilities.apply_func.get_callable_dict + ~flash.core.model.predict_context diff --git a/docs/source/api/data.rst b/docs/source/api/data.rst new file mode 100644 index 0000000000..497fd916e9 --- /dev/null +++ b/docs/source/api/data.rst @@ -0,0 +1,177 @@ +############### +flash.core.data +############### + +.. contents:: + :depth: 1 + :local: + :backlinks: top + +flash.core.data.auto_dataset +____________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.data.auto_dataset.AutoDataset + ~flash.core.data.auto_dataset.BaseAutoDataset + ~flash.core.data.auto_dataset.IterableAutoDataset + +flash.core.data.base_viz +________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.data.base_viz.BaseVisualization + +flash.core.data.batch +________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + ~flash.core.data.batch.default_uncollate + +flash.core.data.callback +________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.data.callback.BaseDataFetcher + ~flash.core.data.callback.ControlFlow + ~flash.core.data.callback.FlashCallback + +flash.core.data.data_module +___________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.data.data_module.DataModule + +flash.core.data.data_pipeline +_____________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.data.data_pipeline.DataPipeline + ~flash.core.data.data_pipeline.DataPipelineState + +flash.core.data.data_source +___________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.data.data_source.DatasetDataSource + ~flash.core.data.data_source.DataSource + ~flash.core.data.data_source.DefaultDataKeys + ~flash.core.data.data_source.DefaultDataSources + ~flash.core.data.data_source.FiftyOneDataSource + ~flash.core.data.data_source.ImageLabelsMap + ~flash.core.data.data_source.LabelsState + ~flash.core.data.data_source.MockDataset + ~flash.core.data.data_source.NumpyDataSource + ~flash.core.data.data_source.PathsDataSource + ~flash.core.data.data_source.SequenceDataSource + ~flash.core.data.data_source.TensorDataSource + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + ~flash.core.data.data_source.has_file_allowed_extension + ~flash.core.data.data_source.has_len + ~flash.core.data.data_source.make_dataset + +flash.core.data.process +_______________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.data.process.BasePreprocess + ~flash.core.data.process.DefaultPreprocess + ~flash.core.data.process.DeserializerMapping + ~flash.core.data.process.Deserializer + ~flash.core.data.process.Postprocess + ~flash.core.data.process.Preprocess + ~flash.core.data.process.SerializerMapping + ~flash.core.data.process.Serializer + +flash.core.data.properties +__________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.data.properties.ProcessState + ~flash.core.data.properties.Properties + +flash.core.data.splits +______________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.data.splits.SplitDataset + +flash.core.data.transforms +__________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.data.transforms.ApplyToKeys + ~flash.core.data.transforms.KorniaParallelTransforms + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + ~flash.core.data.transforms.merge_transforms + ~flash.core.data.transforms.kornia_collate + +flash.core.data.utils +_____________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.data.utils.CurrentFuncContext + ~flash.core.data.utils.CurrentRunningStageContext + ~flash.core.data.utils.CurrentRunningStageFuncContext + ~flash.core.data.utils.FuncModule + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + ~flash.core.data.utils.convert_to_modules + ~flash.core.data.utils.download_data diff --git a/docs/source/api/flash.rst b/docs/source/api/flash.rst new file mode 100644 index 0000000000..06540aad69 --- /dev/null +++ b/docs/source/api/flash.rst @@ -0,0 +1,17 @@ +##### +flash +##### + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~flash.core.data.data_source.DataSource + ~flash.core.data.data_module.DataModule + ~flash.core.data.callback.FlashCallback + ~flash.core.data.process.Preprocess + ~flash.core.data.process.Postprocess + ~flash.core.data.process.Serializer + ~flash.core.model.Task + ~flash.core.trainer.Trainer diff --git a/docs/source/api/graph.rst b/docs/source/api/graph.rst new file mode 100644 index 0000000000..bf94475ab2 --- /dev/null +++ b/docs/source/api/graph.rst @@ -0,0 +1,33 @@ +########### +flash.graph +########### + +.. contents:: + :depth: 1 + :local: + :backlinks: top + +.. currentmodule:: flash.graph + +Classification +______________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~classification.model.GraphClassifier + ~classification.data.GraphClassificationData + + classification.data.GraphClassificationPreprocess + +flash.graph.data +________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~data.GraphDatasetDataSource diff --git a/docs/source/api/image.rst b/docs/source/api/image.rst new file mode 100644 index 0000000000..34d44164a8 --- /dev/null +++ b/docs/source/api/image.rst @@ -0,0 +1,147 @@ +########### +flash.image +########### + +.. contents:: + :depth: 1 + :local: + :backlinks: top + +.. currentmodule:: flash.image + +Classification +______________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~classification.model.ImageClassifier + ~classification.data.ImageClassificationData + ~classification.data.ImageClassificationPreprocess + + classification.data.MatplotlibVisualization + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: + + classification.transforms.default_transforms + classification.transforms.train_default_transforms + +Object Detection +________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~detection.model.ObjectDetector + ~detection.data.ObjectDetectionData + + detection.data.FiftyOneParser + detection.data.ObjectDetectionFiftyOneDataSource + detection.data.ObjectDetectionPreprocess + detection.serialization.DetectionLabels + detection.serialization.FiftyOneDetectionLabels + +Keypoint Detection +__________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~keypoint_detection.model.KeypointDetector + ~keypoint_detection.data.KeypointDetectionData + + keypoint_detection.data.KeypointDetectionPreprocess + +Instance Segmentation +_____________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~instance_segmentation.model.InstanceSegmentation + ~instance_segmentation.data.InstanceSegmentationData + + instance_segmentation.data.InstanceSegmentationPreprocess + +Embedding +_________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~embedding.model.ImageEmbedder + +Segmentation +____________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~segmentation.model.SemanticSegmentation + ~segmentation.data.SemanticSegmentationData + ~segmentation.data.SemanticSegmentationPreprocess + + segmentation.data.SegmentationMatplotlibVisualization + segmentation.data.SemanticSegmentationNumpyDataSource + segmentation.data.SemanticSegmentationTensorDataSource + segmentation.data.SemanticSegmentationPathsDataSource + segmentation.data.SemanticSegmentationFiftyOneDataSource + segmentation.data.SemanticSegmentationDeserializer + segmentation.model.SemanticSegmentationPostprocess + segmentation.serialization.FiftyOneSegmentationLabels + segmentation.serialization.SegmentationLabels + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + segmentation.transforms.default_transforms + segmentation.transforms.prepare_target + segmentation.transforms.train_default_transforms + +Style Transfer +______________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~style_transfer.model.StyleTransfer + ~style_transfer.data.StyleTransferData + ~style_transfer.data.StyleTransferPreprocess + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + ~style_transfer.utils.raise_not_supported + +flash.image.data +________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~data.ImageDeserializer + ~data.ImageFiftyOneDataSource + ~data.ImageNumpyDataSource + ~data.ImagePathsDataSource + ~data.ImageTensorDataSource diff --git a/docs/source/api/pointcloud.rst b/docs/source/api/pointcloud.rst new file mode 100644 index 0000000000..d3c7b94797 --- /dev/null +++ b/docs/source/api/pointcloud.rst @@ -0,0 +1,40 @@ +################ +flash.pointcloud +################ + +.. contents:: + :depth: 1 + :local: + :backlinks: top + +.. currentmodule:: flash.pointcloud + +Segmentation +____________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~segmentation.model.PointCloudSegmentation + ~segmentation.data.PointCloudSegmentationData + + segmentation.data.PointCloudSegmentationPreprocess + segmentation.data.PointCloudSegmentationFoldersDataSource + segmentation.data.PointCloudSegmentationDatasetDataSource + +Object Detection +________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~detection.model.PointCloudObjectDetector + ~detection.data.PointCloudObjectDetectorData + + detection.data.PointCloudObjectDetectorPreprocess + detection.data.PointCloudObjectDetectorFoldersDataSource + detection.data.PointCloudObjectDetectorDatasetDataSource diff --git a/docs/source/api/serve.rst b/docs/source/api/serve.rst new file mode 100644 index 0000000000..66406c6242 --- /dev/null +++ b/docs/source/api/serve.rst @@ -0,0 +1,14 @@ +################ +flash.core.serve +################ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate + + ~flash.core.serve.component.ModelComponent + ~flash.core.serve.composition.Composition + ~flash.core.serve.core.Endpoint + ~flash.core.serve.core.Servable + ~flash.core.serve.decorators.expose diff --git a/docs/source/api/tabular.rst b/docs/source/api/tabular.rst new file mode 100644 index 0000000000..0752a5ca52 --- /dev/null +++ b/docs/source/api/tabular.rst @@ -0,0 +1,46 @@ +############# +flash.tabular +############# + +.. contents:: + :depth: 2 + :local: + :backlinks: top + +.. currentmodule:: flash.tabular + +Classification +______________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~classification.model.TabularClassifier + ~classification.data.TabularClassificationData + +Regression +__________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~regression.data.TabularRegressionData + +flash.tabular.data +__________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~data.TabularData + ~data.TabularDataFrameDataSource + ~data.TabularCSVDataSource + ~data.TabularDeserializer + ~data.TabularPreprocess + ~data.TabularPostprocess diff --git a/docs/source/api/text.rst b/docs/source/api/text.rst new file mode 100644 index 0000000000..f9177eec85 --- /dev/null +++ b/docs/source/api/text.rst @@ -0,0 +1,93 @@ +########## +flash.text +########## + +.. contents:: + :depth: 1 + :local: + :backlinks: top + +.. currentmodule:: flash.text + +Classification +______________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~classification.model.TextClassifier + ~classification.data.TextClassificationData + + classification.data.TextClassificationPostprocess + classification.data.TextClassificationPreprocess + classification.data.TextCSVDataSource + classification.data.TextDataSource + classification.data.TextDeserializer + classification.data.TextFileDataSource + classification.data.TextJSONDataSource + classification.data.TextSentencesDataSource + +Question Answering +__________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~seq2seq.question_answering.model.QuestionAnsweringTask + ~seq2seq.question_answering.data.QuestionAnsweringData + + seq2seq.question_answering.data.QuestionAnsweringPreprocess + +Summarization +_____________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~seq2seq.summarization.model.SummarizationTask + ~seq2seq.summarization.data.SummarizationData + + seq2seq.summarization.data.SummarizationPreprocess + +Translation +___________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~seq2seq.translation.model.TranslationTask + ~seq2seq.translation.data.TranslationData + + seq2seq.translation.data.TranslationPreprocess + +General Seq2Seq +_______________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~seq2seq.core.model.Seq2SeqTask + ~seq2seq.core.data.Seq2SeqData + ~seq2seq.core.finetuning.Seq2SeqFreezeEmbeddings + + seq2seq.core.data.Seq2SeqBackboneState + seq2seq.core.data.Seq2SeqCSVDataSource + seq2seq.core.data.Seq2SeqDataSource + seq2seq.core.data.Seq2SeqFileDataSource + seq2seq.core.data.Seq2SeqJSONDataSource + seq2seq.core.data.Seq2SeqPostprocess + seq2seq.core.data.Seq2SeqPreprocess + seq2seq.core.data.Seq2SeqSentencesDataSource + seq2seq.core.metrics.BLEUScore + seq2seq.core.metrics.RougeBatchAggregator + seq2seq.core.metrics.RougeMetric diff --git a/docs/source/api/video.rst b/docs/source/api/video.rst new file mode 100644 index 0000000000..ade63234ca --- /dev/null +++ b/docs/source/api/video.rst @@ -0,0 +1,27 @@ +########### +flash.video +########### + +.. contents:: + :depth: 1 + :local: + :backlinks: top + +.. currentmodule:: flash.video + +Classification +______________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~classification.model.VideoClassifier + ~classification.data.VideoClassificationData + + classification.data.BaseVideoClassification + classification.data.VideoClassificationFiftyOneDataSource + classification.data.VideoClassificationPathsDataSource + classification.data.VideoClassificationPreprocess + classification.model.VideoClassifierFinetuning diff --git a/docs/source/code/core.rst b/docs/source/code/core.rst deleted file mode 100644 index d7475c9491..0000000000 --- a/docs/source/code/core.rst +++ /dev/null @@ -1,36 +0,0 @@ -########## -flash.core -########## - -.. contents:: - :depth: 2 - :local: - :backlinks: top - -Models and Backbones -____________________ - -The Task -======== - -.. automodule:: flash.core.model - -.. autoclass:: flash.core.classification.ClassificationTask - -Fitting and Finetuning -______________________ - -Trainer -======= - -.. automodule:: flash.core.trainer - -Finetuning Callbacks -==================== - -.. automodule:: flash.core.finetuning - -Registry -________ - -.. automodule:: flash.core.registry diff --git a/docs/source/code/data.rst b/docs/source/code/data.rst deleted file mode 100644 index 3d4d85aa9d..0000000000 --- a/docs/source/code/data.rst +++ /dev/null @@ -1,85 +0,0 @@ -############### -flash.core.data -############### - -.. contents:: - :depth: 2 - :local: - :backlinks: top - -Data Loading -____________ - -Data Module -=========== - -.. automodule:: flash.core.data.data_module - -Data Sources -============ - -.. automodule:: flash.core.data.data_source - -Data Processing -_______________ - -Data Pipeline -============= - -.. automodule:: flash.core.data.data_pipeline - -Data Pipeline Components -======================== - -.. automodule:: flash.core.data.properties - -.. automodule:: flash.core.data.process - -Transforms -__________ - -.. currentmodule:: flash.core.data.transforms - -Helpers -======= - -ApplyToKeys ------------ - -.. autoclass:: ApplyToKeys - -merge_transforms ----------------- - -.. autofunction:: merge_transforms - -Kornia -====== - -KorniaParallelTransforms ------------------------- - -.. autoclass:: KorniaParallelTransforms - -kornia_collate --------------- - -.. autofunction:: kornia_collate - -Callbacks and Visualizations -____________________________ - -.. automodule:: flash.core.data.base_viz - -.. automodule:: flash.core.data.callback - -Utilities -_________ - -.. automodule:: flash.core.data.auto_dataset - -.. automodule:: flash.core.data.batch - -.. automodule:: flash.core.data.splits - -.. automodule:: flash.core.data.utils diff --git a/docs/source/code/image.rst b/docs/source/code/image.rst deleted file mode 100644 index 969963ae23..0000000000 --- a/docs/source/code/image.rst +++ /dev/null @@ -1,86 +0,0 @@ -########### -flash.image -########### - -.. contents:: - :depth: 1 - :local: - :backlinks: top - -Classification -______________ - -Data -==== - -.. automodule:: flash.image.classification.data - -.. automodule:: flash.image.classification.transforms - -Task -==== - -.. automodule:: flash.image.classification.model - -Detection -_________ - -Data -==== - -.. automodule:: flash.image.detection.data - -.. automodule:: flash.image.detection.transforms - -Task -==== - -.. automodule:: flash.image.detection.model - -Finetuning -========== - -.. automodule:: flash.image.detection.finetuning - -Embedding -_________ - -Task -==== - -.. automodule:: flash.image.embedding.model - -Segmentation -____________ - -Data -==== - -.. automodule:: flash.image.segmentation.data - -.. automodule:: flash.image.segmentation.transforms - -.. automodule:: flash.image.segmentation.serialization - -Task -==== - -.. automodule:: flash.image.segmentation.model - -Style Transfer -______________ - -Data -==== - -.. automodule:: flash.image.style_transfer.data - -Task -==== - -.. automodule:: flash.image.style_transfer.model - -General -_______ - -.. automodule:: flash.image.data diff --git a/docs/source/code/tabular.rst b/docs/source/code/tabular.rst deleted file mode 100644 index 5e8d0caffd..0000000000 --- a/docs/source/code/tabular.rst +++ /dev/null @@ -1,21 +0,0 @@ -############# -flash.tabular -############# - -.. contents:: - :depth: 1 - :local: - :backlinks: top - -Classification -______________ - -Data -==== - -.. automodule:: flash.tabular.classification.data - -Task -==== - -.. automodule:: flash.tabular.classification.model diff --git a/docs/source/code/text.rst b/docs/source/code/text.rst deleted file mode 100644 index 0a23bfbe91..0000000000 --- a/docs/source/code/text.rst +++ /dev/null @@ -1,81 +0,0 @@ -########## -flash.text -########## - -.. contents:: - :depth: 1 - :local: - :backlinks: top - -Classification -______________ - -Data -==== - -.. automodule:: flash.text.classification.data - -Task -==== - -.. automodule:: flash.text.classification.model - -Seq2Seq -_______ - -General -======= - -Data -**** - -.. automodule:: flash.text.seq2seq.core.data - -Task -**** - -.. automodule:: flash.text.seq2seq.core.model - -Finetuning -********** - -.. automodule:: flash.text.seq2seq.core.finetuning - -Summarization -============= - -Data -**** - -.. automodule:: flash.text.seq2seq.summarization.data - :members: SummarizationData - -Task -**** - -.. automodule:: flash.text.seq2seq.summarization.model - -Metric -****** - -.. automodule:: flash.text.seq2seq.summarization.metric - -.. automodule:: flash.text.seq2seq.summarization.utils - -Translation -=========== - -Data -**** - -.. automodule:: flash.text.seq2seq.translation.data - -Task -**** - -.. automodule:: flash.text.seq2seq.translation.model - -Metric -****** - -.. automodule:: flash.text.seq2seq.translation.metric diff --git a/docs/source/code/video.rst b/docs/source/code/video.rst deleted file mode 100644 index 471b11fb7a..0000000000 --- a/docs/source/code/video.rst +++ /dev/null @@ -1,21 +0,0 @@ -########### -flash.video -########### - -.. contents:: - :depth: 1 - :local: - :backlinks: top - -Classification -______________ - -Data -==== - -.. automodule:: flash.video.classification.data - -Task -==== - -.. automodule:: flash.video.classification.model diff --git a/docs/source/common/finetuning_example.rst b/docs/source/common/finetuning_example.rst index 23d56ddf3b..b45b0cfd97 100644 --- a/docs/source/common/finetuning_example.rst +++ b/docs/source/common/finetuning_example.rst @@ -35,7 +35,7 @@ Here's an example of finetuning. model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes) # 3. Create the trainer (run one epoch for demo) - trainer = flash.Trainer(max_epochs=1) + trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count()) # 4. Finetune the model trainer.finetune(model, datamodule=datamodule, strategy="freeze") @@ -58,7 +58,12 @@ Once you've finetuned, use the model to predict: # Serialize predictions as labels, automatically inferred from the training data in part 2. model.serializer = Labels() - predictions = model.predict(["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", "data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg"]) + predictions = model.predict( + [ + "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", + "data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg", + ] + ) print(predictions) We get the following output: @@ -86,4 +91,4 @@ Or you can use the saved model for prediction anywhere you want! # load finetuned checkpoint model = ImageClassifier.load_from_checkpoint("image_classification_model.pt") - predictions = model.predict('path/to/your/own/image.png') + predictions = model.predict("path/to/your/own/image.png") diff --git a/docs/source/common/training_example.rst b/docs/source/common/training_example.rst index c936f47b7f..9a015cda65 100644 --- a/docs/source/common/training_example.rst +++ b/docs/source/common/training_example.rst @@ -23,7 +23,7 @@ Here's an example: seed_everything(42) # 1. Download and organize the data - download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') + download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") datamodule = ImageClassificationData.from_folders( train_folder="data/hymenoptera_data/train/", @@ -35,7 +35,7 @@ Here's an example: model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, pretrained=False) # 3. Create the trainer (run one epoch for demo) - trainer = flash.Trainer(max_epochs=1) + trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count()) # 4. Train the model trainer.fit(model, datamodule=datamodule) diff --git a/docs/source/conf.py b/docs/source/conf.py index c295154a58..8374dc8bb9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -10,18 +10,21 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +import glob import os +import shutil import sys from importlib.util import module_from_spec, spec_from_file_location import pt_lightning_sphinx_theme _PATH_HERE = os.path.abspath(os.path.dirname(__file__)) -_PATH_ROOT = os.path.join(_PATH_HERE, '..', '..') +_PATH_ROOT = os.path.join(_PATH_HERE, "..", "..") sys.path.insert(0, os.path.abspath(_PATH_ROOT)) try: from flash import __about__ as about + from flash.core.utilities import providers except ModuleNotFoundError: @@ -32,10 +35,11 @@ def _load_py_module(fname, pkg="flash"): return py about = _load_py_module("__about__.py") + providers = _load_py_module("flash/core/utilities/providers.py") -SPHINX_MOCK_REQUIREMENTS = int(os.environ.get('SPHINX_MOCK_REQUIREMENTS', True)) +SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True)) -html_favicon = '_static/images/icon.svg' +html_favicon = "_static/images/icon.svg" # -- Project information ----------------------------------------------------- @@ -43,28 +47,67 @@ def _load_py_module(fname, pkg="flash"): copyright = "2020-2021, PyTorch Lightning" author = "PyTorch Lightning" +# -- Project documents ------------------------------------------------------- + + +def _transform_changelog(path_in: str, path_out: str) -> None: + with open(path_in) as fp: + chlog_lines = fp.readlines() + # enrich short subsub-titles to be unique + chlog_ver = "" + for i, ln in enumerate(chlog_lines): + if ln.startswith("## "): + chlog_ver = ln[2:].split("-")[0].strip() + elif ln.startswith("### "): + ln = ln.replace("###", f"### {chlog_ver} -") + chlog_lines[i] = ln + with open(path_out, "w") as fp: + fp.writelines(chlog_lines) + + +generated_dir = os.path.join(_PATH_HERE, "generated") + +os.makedirs(generated_dir, exist_ok=True) +# copy all documents from GH templates like contribution guide +for md in glob.glob(os.path.join(_PATH_ROOT, ".github", "*.md")): + shutil.copy(md, os.path.join(generated_dir, os.path.basename(md))) +# copy also the changelog +_transform_changelog(os.path.join(_PATH_ROOT, "CHANGELOG.md"), os.path.join(generated_dir, "CHANGELOG.md")) + +# -- Generate providers ------------------------------------------------------ + +lines = [] +for provider in providers.PROVIDERS: + lines.append(f"- {str(provider)}\n") + +generated_dir = os.path.join("integrations", "generated") +os.makedirs(generated_dir, exist_ok=True) + +with open(os.path.join(generated_dir, "providers.rst"), "w") as f: + f.writelines(sorted(lines, key=str.casefold)) + # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - # 'sphinx.ext.todo', + "sphinx.ext.autodoc", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", # 'sphinx.ext.coverage', - 'sphinx.ext.viewcode', - 'sphinx.ext.autosummary', - 'sphinx.ext.napoleon', - 'sphinx.ext.imgmath', - 'recommonmark', + "sphinx.ext.viewcode", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinx.ext.imgmath", + "recommonmark", # 'sphinx.ext.autosectionlabel', # 'nbsphinx', # it seems some sphinx issue - 'sphinx_autodoc_typehints', - 'sphinx_copybutton', - 'sphinx_paramlinks', - 'sphinx_togglebutton', + "sphinx_autodoc_typehints", + "sphinx_copybutton", + "sphinx_paramlinks", + "sphinx_togglebutton", ] # autodoc: Default to members and undoc-members @@ -79,7 +122,7 @@ def _load_py_module(fname, pkg="flash"): # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = [] +exclude_patterns = ["generated/PULL_REQUEST_TEMPLATE.md"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: @@ -114,8 +157,8 @@ def _load_py_module(fname, pkg="flash"): # documentation. html_theme_options = { - 'pytorch_project': 'https://pytorchlightning.ai', - 'canonical_url': about.__docs_url__, + "pytorch_project": "https://pytorchlightning.ai", + "canonical_url": about.__docs_url__, "collapse_navigation": False, "display_version": True, "logo_only": False, @@ -132,19 +175,20 @@ def _load_py_module(fname, pkg="flash"): def setup(app): # this is for hiding doctest decoration, # see: http://z4r.github.io/python/2011/12/02/hides-the-prompts-and-output/ - app.add_js_file('copybutton.js') + app.add_js_file("copybutton.js") + app.add_css_file("main.css") # Ignoring Third-party packages # https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule def _package_list_from_file(pfile): assert os.path.isfile(pfile) - with open(pfile, 'r') as fp: + with open(pfile) as fp: lines = fp.readlines() list_pkgs = [] for ln in lines: - found = [ln.index(ch) for ch in list(',=<>#@') if ch in ln] - pkg = ln[:min(found)] if found else ln + found = [ln.index(ch) for ch in list(",=<>#@") if ch in ln] + pkg = ln[: min(found)] if found else ln if pkg.strip(): list_pkgs.append(pkg.strip()) return list_pkgs @@ -152,26 +196,26 @@ def _package_list_from_file(pfile): # define mapping from PyPI names to python imports PACKAGE_MAPPING = { - 'pytorch-lightning': 'pytorch_lightning', - 'scikit-learn': 'sklearn', - 'Pillow': 'PIL', - 'PyYAML': 'yaml', - 'rouge-score': 'rouge_score', - 'lightning-bolts': 'pl_bolts', - 'pytorch-tabnet': 'pytorch_tabnet', - 'pyDeprecate': 'deprecate', + "pytorch-lightning": "pytorch_lightning", + "scikit-learn": "sklearn", + "Pillow": "PIL", + "PyYAML": "yaml", + "rouge-score": "rouge_score", + "lightning-bolts": "pl_bolts", + "pytorch-tabnet": "pytorch_tabnet", + "pyDeprecate": "deprecate", } -MOCK_PACKAGES = [] +MOCK_PACKAGES = ["numpy", "PyYAML", "tqdm"] if SPHINX_MOCK_REQUIREMENTS: # mock also base packages when we are on RTD since we don't install them there - MOCK_PACKAGES += _package_list_from_file(os.path.join(_PATH_ROOT, 'requirements.txt')) + MOCK_PACKAGES += _package_list_from_file(os.path.join(_PATH_ROOT, "requirements.txt")) # replace PyPI packages by importing ones MOCK_PACKAGES = [PACKAGE_MAPPING.get(pkg, pkg) for pkg in MOCK_PACKAGES] autodoc_mock_imports = MOCK_PACKAGES # only run doctests marked with a ".. doctest::" directive -doctest_test_doctest_blocks = '' +doctest_test_doctest_blocks = "" doctest_global_setup = """ import torch import pytorch_lightning as pl diff --git a/docs/source/custom_task.rst b/docs/source/custom_task.rst index 937d5a1674..0bd374deea 100644 --- a/docs/source/custom_task.rst +++ b/docs/source/custom_task.rst @@ -55,7 +55,6 @@ It's best practice in flash for the data to be provide as a dictionary which map .. testcode:: custom_task class RegressionTask(flash.Task): - def __init__(self, num_inputs, learning_rate=0.2, metrics=None): # what kind of model do we want? model = torch.nn.Linear(num_inputs, 1) @@ -149,7 +148,6 @@ generated ``dataset``. .. testcode:: custom_task class NumpyDataSource(DataSource[Tuple[ND, ND]]): - def load_data(self, data: Tuple[ND, ND], dataset: Optional[Any] = None) -> List[Dict[str, Any]]: if self.training: dataset.num_inputs = data[0].shape[1] @@ -191,7 +189,6 @@ The recommended way to define a custom :class:`~flash.core.data.process.Preproce .. testcode:: custom_task class NumpyPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -282,7 +279,9 @@ supplying the task itself, and the associated data: model = RegressionTask(num_inputs=datamodule.train_dataset.num_inputs) - trainer = flash.Trainer(max_epochs=20, progress_bar_refresh_rate=20, checkpoint_callback=False) + trainer = flash.Trainer( + max_epochs=20, progress_bar_refresh_rate=20, checkpoint_callback=False, gpus=torch.cuda.device_count() + ) trainer.fit(model, datamodule=datamodule) @@ -299,13 +298,15 @@ With a trained model we can now perform inference. Here we will use a few exampl .. testcode:: custom_task - predict_data = np.array([ - [ 0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403], - [-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072], - [ 0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072], - [-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384], - [-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094], - ]) + predict_data = np.array( + [ + [0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403], + [-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072], + [0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072], + [-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384], + [-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094], + ] + ) predictions = model.predict(predict_data) print(predictions) diff --git a/docs/source/general/callback.rst b/docs/source/general/callback.rst deleted file mode 100644 index 504d499f41..0000000000 --- a/docs/source/general/callback.rst +++ /dev/null @@ -1,23 +0,0 @@ -######## -Callback -######## - -.. _callback: - -************** -Flash Callback -************** - -:class:`~flash.core.data.callback.FlashCallback` is an extension of :class:`pytorch_lightning.callbacks.Callback`. - -A callback is a self-contained program that can be reused across projects. - -Flash and Lightning have a callback system to execute callbacks when needed. - -Callbacks should capture any NON-ESSENTIAL logic that is NOT required for your lightning module to run. - -Same as PyTorch Lightning, Callbacks can be provided directly to the Trainer. - -Example:: - - trainer = Trainer(callbacks=[MyCustomCallback()]) diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst index f557ce466e..8e815c5a83 100644 --- a/docs/source/general/data.rst +++ b/docs/source/general/data.rst @@ -21,25 +21,28 @@ Here are common terms you need to be familiar with: * - Term - Definition + * - :class:`~flash.core.data.process.Deserializer` + - The :class:`~flash.core.data.process.Deserializer` provides a single :meth:`~flash.core.data.process.Deserializer.deserialize` method. * - :class:`~flash.core.data.data_module.DataModule` - The :class:`~flash.core.data.data_module.DataModule` contains the datasets, transforms and dataloaders. * - :class:`~flash.core.data.data_pipeline.DataPipeline` - - The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage: :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, and :class:`~flash.core.data.process.Serializer` objects. + - The :class:`~flash.core.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.core.data.Deserializer`, :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, and :class:`~flash.core.data.process.Serializer` objects. * - :class:`~flash.core.data.data_source.DataSource` - The :class:`~flash.core.data.data_source.DataSource` provides :meth:`~flash.core.data.data_source.DataSource.load_data` and :meth:`~flash.core.data.data_source.DataSource.load_sample` hooks for creating data sets from metadata (such as folder names). * - :class:`~flash.core.data.process.Preprocess` - The :class:`~flash.core.data.process.Preprocess` provides a simple hook-based API to encapsulate your pre-processing logic. These hooks (such as :meth:`~flash.core.data.process.Preprocess.pre_tensor_transform`) enable transformations to be applied to your data at every point along the pipeline (including on the device). The :class:`~flash.core.data.data_pipeline.DataPipeline` contains a system to call the right hooks when needed. - The :class:`~flash.core.data.process.Preprocess` hooks can be either overriden directly or provided as a dictionary of transforms (mapping hook name to callable transform). + The :class:`~flash.core.data.process.Preprocess` hooks can be either overridden directly or provided as a dictionary of transforms (mapping hook name to callable transform). * - :class:`~flash.core.data.process.Postprocess` - The :class:`~flash.core.data.process.Postprocess` provides a simple hook-based API to encapsulate your post-processing logic. The :class:`~flash.core.data.process.Postprocess` hooks cover from model outputs to predictions export. * - :class:`~flash.core.data.process.Serializer` - - The :class:`~flash.core.data.process.Serializer` provides a single ``serialize`` method that is used to convert model outputs (after the :class:`~flash.core.data.process.Postprocess`) to the desired output format during prediction. + - The :class:`~flash.core.data.process.Serializer` provides a single :meth:`~flash.core.data.process.Serializer.serialize` method that is used to convert model outputs (after the :class:`~flash.core.data.process.Postprocess`) to the desired output format during prediction. + ******************************************* -How to use out-of-the-box flashdatamodules +How to use out-of-the-box Flash DataModules ******************************************* Flash provides several DataModules with helpers functions. @@ -49,14 +52,14 @@ Check out the :ref:`image_classification` section (or the sections for any of ou Data Processing *************** -Currently, it is common practice to implement a :class:`pytorch.utils.data.Dataset` -and provide it to a :class:`pytorch.utils.data.DataLoader`. +Currently, it is common practice to implement a :class:`torch.utils.data.Dataset` +and provide it to a :class:`torch.utils.data.DataLoader`. However, after model training, it requires a lot of engineering overhead to make inference on raw data and deploy the model in production environment. Usually, extra processing logic should be added to bridge the gap between training data and raw data. The :class:`~flash.core.data.data_source.DataSource` class can be used to generate data sets from multiple sources (e.g. folders, numpy, etc.), that can then all be transformed in the same way. The :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess` classes can be used to manage the preprocessing and postprocessing transforms. -The :class:`~flash.core.data.process.Serializer` class provides the logic for converting :class:`~flash.core.data.process.Postprocess` outputs to the desired predict format (e.g. classes, labels, probabilites, etc.). +The :class:`~flash.core.data.process.Serializer` class provides the logic for converting :class:`~flash.core.data.process.Postprocess` outputs to the desired predict format (e.g. classes, labels, probabilities, etc.). By providing a series of hooks that can be overridden with custom data processing logic (or just targeted with transforms), Flash gives the user much more granular control over their data processing flow. @@ -75,15 +78,14 @@ hooks by adding ``train``, ``val``, ``test`` or ``predict``. Check out :class:`~flash.core.data.process.Preprocess` for some examples. ************************************* -How to customize existing datamodules +How to customize existing DataModules ************************************* Any Flash :class:`~flash.core.data.data_module.DataModule` can be created directly from datasets using the :meth:`~flash.core.data.data_module.DataModule.from_datasets` like this: .. code-block:: python - from flash import Trainer - from flash.core.data.data_module import DataModule + from flash import DataModule, Trainer data_module = DataModule.from_datasets(train_dataset=MyDataset()) trainer = Trainer() @@ -95,6 +97,10 @@ In each ``from_*`` method, the :class:`~flash.core.data.data_module.DataModule` Flash :class:`~flash.core.data.auto_dataset.AutoDataset` instances are created from the :class:`~flash.core.data.data_source.DataSource` for train, val, test, and predict. The :class:`~flash.core.data.data_module.DataModule` populates the ``DataLoader`` for each stage with the corresponding :class:`~flash.core.data.auto_dataset.AutoDataset`. +************************************** +Customize preprocessing of DataModules +************************************** + The :class:`~flash.core.data.process.Preprocess` contains the processing logic related to a given task. Each :class:`~flash.core.data.process.Preprocess` provides some default transforms through the :meth:`~flash.core.data.process.Preprocess.default_transforms` method. Users can easily override these by providing their own transforms to the :class:`~flash.core.data.data_module.DataModule`. @@ -105,9 +111,7 @@ Here's an example: from flash.core.data.transforms import ApplyToKeys from flash.image import ImageClassificationData, ImageClassifier - transform = { - "to_tensor_transform": ApplyToKeys("input", my_to_tensor_transform) - } + transform = {"to_tensor_transform": ApplyToKeys("input", my_to_tensor_transform)} datamodule = ImageClassificationData.from_folders( train_folder="data/hymenoptera_data/train/", @@ -125,12 +129,13 @@ Alternatively, the user may directly override the hooks for their needs like thi from typing import Any, Dict from flash.image import ImageClassificationData, ImageClassifier, ImageClassificationPreprocess - class CustomImageClassificationPreprocess(ImageClassificationPreprocess): + class CustomImageClassificationPreprocess(ImageClassificationPreprocess): def to_tensor_transform(sample: Dict[str, Any]) -> Dict[str, Any]: sample["input"] = my_to_tensor_transform(sample["input"]) return sample + datamodule = ImageClassificationData.from_folders( train_folder="data/hymenoptera_data/train/", val_folder="data/hymenoptera_data/val/", @@ -139,16 +144,16 @@ Alternatively, the user may directly override the hooks for their needs like thi ) -****************************** -Custom Preprocess + Datamodule -****************************** +***************************************** +Create your own Preprocess and DataModule +***************************************** The example below shows a very simple ``ImageClassificationPreprocess`` with a single ``ImageClassificationFoldersDataSource`` and an ``ImageClassificationDataModule``. 1. User-Facing API design _________________________ -Designing an easy to use API is key. This is the first and most important step. +Designing an easy-to-use API is key. This is the first and most important step. We want the ``ImageClassificationDataModule`` to generate a dataset from folders of images arranged in this way. Example:: @@ -189,26 +194,33 @@ Here's the full ``ImageClassificationFoldersDataSource``: from typing import Any, Dict from flash.core.data.data_source import DataSource, DefaultDataKeys - class ImageClassificationFoldersDataSource(DataSource): + class ImageClassificationFoldersDataSource(DataSource): def load_data(self, folder: str, dataset: Any) -> Iterable: # The dataset is optional but can be useful to save some metadata. - # metadata contains the image path and its corresponding label with the following structure: + # `metadata` contains the image path and its corresponding label + # with the following structure: # [(image_path_1, label_1), ... (image_path_n, label_n)]. metadata = make_dataset(folder) - # for the train ``AutoDataset``, we want to store the ``num_classes``. + # for the train `AutoDataset`, we want to store the `num_classes`. if self.training: dataset.num_classes = len(np.unique([m[1] for m in metadata])) - return [{DefaultDataKeys.INPUT: file, DefaultDataKeys.TARGET: target} for file, target in metadata] + return [ + { + DefaultDataKeys.INPUT: file, + DefaultDataKeys.TARGET: target, + } + for file, target in metadata + ] def predict_load_data(self, predict_folder: str) -> Iterable: # This returns [image_path_1, ... image_path_m]. return [{DefaultDataKeys.INPUT: file} for file in os.listdir(folder)] - def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any] + def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: sample[DefaultDataKeys.INPUT] = Image.open(sample[DefaultDataKeys.INPUT]) return sample @@ -226,9 +238,8 @@ Next, implement your custom ``ImageClassificationPreprocess`` with some default from flash.core.data.process import Preprocess import torchvision.transforms.functional as T - # Subclass ``Preprocess`` + # Subclass `Preprocess` class ImageClassificationPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -255,9 +266,7 @@ Next, implement your custom ``ImageClassificationPreprocess`` with some default return cls(**state_dict) def default_transforms(self) -> Dict[str, Callable]: - return { - "to_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.to_tensor) - } + return {"to_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.to_tensor)} 4. The DataModule _________________ @@ -268,11 +277,12 @@ All we need to do is attach our :class:`~flash.core.data.process.Preprocess` cla .. code-block:: python - from flash.core.data.data_module import DataModule + from flash import DataModule + class ImageClassificationDataModule(DataModule): - # Set ``preprocess_cls`` with your custom ``preprocess``. + # Set `preprocess_cls` with your custom `Preprocess`. preprocess_cls = ImageClassificationPreprocess @@ -283,24 +293,27 @@ How it works behind the scenes DataSource __________ -.. note:: The ``load_data`` and ``load_sample`` will be used to generate an AutoDataset object. +.. note:: + The :meth:`~flash.core.data.data_source.DataSource.load_data` and + :meth:`~flash.core.data.data_source.DataSource.load_sample` will be used to generate an + :class:`~flash.core.data.auto_dataset.AutoDataset` object. -Here is the ``AutoDataset`` pseudo-code. +Here is the :class:`~flash.core.data.auto_dataset.AutoDataset` pseudo-code. -Example:: +.. code-block:: python - class AutoDataset + class AutoDataset: def __init__( self, - data: List[Any], # The result of a call to DataSource.load_data + data: List[Any], # output of `DataSource.load_data` data_source: DataSource, running_stage: RunningStage, - ) -> None: + ): self.data = data self.data_source = data_source - def __getitem__(self, index): + def __getitem__(self, index: int): return self.data_source.load_sample(self.data[index]) def __len__(self): @@ -311,8 +324,12 @@ __________ .. note:: - The ``pre_tensor_transform``, ``to_tensor_transform``, ``post_tensor_transform``, ``collate``, - ``per_batch_transform`` are injected as the ``collate_fn`` function of the DataLoader. + The :meth:`~flash.core.data.process.Preprocess.pre_tensor_transform`, + :meth:`~flash.core.data.process.Preprocess.to_tensor_transform`, + :meth:`~flash.core.data.process.Preprocess.post_tensor_transform`, + :meth:`~flash.core.data.process.Preprocess.collate`, + :meth:`~flash.core.data.process.Preprocess.per_batch_transform` are injected as the + :paramref:`torch.utils.data.DataLoader.collate_fn` function of the DataLoader. Here is the pseudo code using the preprocess hooks name. Flash takes care of calling the right hooks for each stage. @@ -385,7 +402,7 @@ Here is the pseudo-code: Example:: - # This will be wrapped into a :class:`~flash.core.data.batch._Preprocessor` + # This will be wrapped into a :class:`~flash.core.data.batch._Postprocessor` def uncollate_fn(batch: Any) -> Any: batch = per_batch_transform(batch) diff --git a/docs/source/general/finetuning.rst b/docs/source/general/finetuning.rst index e10dd7eeee..46e48ae974 100644 --- a/docs/source/general/finetuning.rst +++ b/docs/source/general/finetuning.rst @@ -52,7 +52,7 @@ Finetune strategies from flash.core.data.utils import download_data from flash.image import ImageClassificationData, ImageClassifier - download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') + download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") datamodule = ImageClassificationData.from_files( train_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"], @@ -104,11 +104,6 @@ The freeze strategy keeps the backbone frozen throughout. trainer.finetune(model, datamodule, strategy="freeze") -.. testoutput:: strategies - :hide: - - ... - The pseudocode looks like: .. code-block:: python @@ -139,11 +134,6 @@ By default, in this strategy the backbone is frozen for 5 epochs then unfrozen: trainer.finetune(model, datamodule, strategy="freeze_unfreeze") -.. testoutput:: strategies - :hide: - - ... - Or we can customize it unfreeze the backbone after a different epoch. For example, to unfreeze after epoch 7: @@ -153,11 +143,6 @@ For example, to unfreeze after epoch 7: trainer.finetune(model, datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=7)) -.. testoutput:: strategies - :hide: - - ... - Under the hood, the pseudocode looks like: .. code-block:: python @@ -193,11 +178,6 @@ Here's an example where: trainer.finetune(model, datamodule, strategy=UnfreezeMilestones(unfreeze_milestones=(3, 8), num_layers=2)) -.. testoutput:: strategies - :hide: - - ... - Under the hood, the pseudocode looks like: .. code-block:: python @@ -231,14 +211,13 @@ For even more customization, create your own finetuning callback. Learn more abo # Create a finetuning callback class FeatureExtractorFreezeUnfreeze(FlashBaseFinetuning): - def __init__(self, unfreeze_epoch: int = 5, train_bn: bool = True): # this will set self.attr_names as ["backbone"] super().__init__("backbone", train_bn) self._unfreeze_epoch = unfreeze_epoch def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx): - # unfreeze any module you want by overriding this function + # unfreeze any module you want by overriding this function # When ``current_epoch`` is 5, backbone will start to be trained. if current_epoch == self._unfreeze_epoch: @@ -247,10 +226,6 @@ For even more customization, create your own finetuning callback. Learn more abo optimizer, ) + # Pass the callback to trainer.finetune trainer.finetune(model, datamodule, strategy=FeatureExtractorFreezeUnfreeze(unfreeze_epoch=5)) - -.. testoutput:: strategies - :hide: - - ... diff --git a/docs/source/general/flash_zero.rst b/docs/source/general/flash_zero.rst new file mode 100644 index 0000000000..da3f73cbb3 --- /dev/null +++ b/docs/source/general/flash_zero.rst @@ -0,0 +1,56 @@ +.. _flash_zero: + +********** +Flash Zero +********** + +Flash Zero is a zero-code machine learning platform built directly into lightning-flash. +To get started and view the available tasks, run: + +.. code-block:: bash + + flash --help + +Customize Trainer and Model arguments +_____________________________________ + +Flash Zero is built on top of the +`lightning CLI `_, so the trainer and +model arguments can be configured either from the command line or from a config file. +For example, to run the image classifier for 10 epochs with a `resnet50` backbone you can use: + +.. code-block:: bash + + flash image_classification --trainer.max_epochs 10 --model.backbone resnet50 + +To view all of the available options for a task, run: + +.. code-block:: bash + + flash image_classification --help + +Using Custom Data +_________________ + +Flash Zero works with your own data through subcommands. The available subcommands for each task are given at the bottom +of their help pages (e.g. when running :code:`flash image-classification --help`). You can then use the required +subcommand to train on your own data. Let's look at an example using the Hymenoptera data from the +:ref:`image_classification` guide. First, download and unzip your data: + +.. code-block:: bash + + curl https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip -o hymenoptera_data + unzip hymenoptera_data.zip + +Now train with Flash Zero: + +.. code-block:: bash + + flash image_classification from_folders --train_folder ./hymenoptera_data/train + +You can view the help page for each subcommand. For example, to view the options for training an image classifier from +folders, you can run: + +.. code-block:: bash + + flash image_classification from_folders --help diff --git a/docs/source/general/jit.rst b/docs/source/general/jit.rst index a0d80f7c51..bce94fcdde 100644 --- a/docs/source/general/jit.rst +++ b/docs/source/general/jit.rst @@ -28,7 +28,7 @@ This table gives a breakdown of the supported features. - Yes - Yes * - :class:`~flash.image.segmentation.model.SemanticSegmentation` - - Yes + - No - Yes - Yes * - :class:`~flash.image.style_transfer.model.StyleTransfer` diff --git a/docs/source/general/predictions.rst b/docs/source/general/predictions.rst index e2b62f6e41..4bd260db99 100644 --- a/docs/source/general/predictions.rst +++ b/docs/source/general/predictions.rst @@ -15,19 +15,19 @@ You can pass in a sample of data (image file path, a string of text, etc) to the .. code-block:: python - from flash.core.data.utils import download_data - from flash.image import ImageClassifier + from flash.core.data.utils import download_data + from flash.image import ImageClassifier - # 1. Download the data set - download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') + # 1. Download the data set + download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") - # 2. Load the model from a checkpoint - model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") + # 2. Load the model from a checkpoint + model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") - # 3. Predict whether the image contains an ant or a bee - predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") - print(predictions) + # 3. Predict whether the image contains an ant or a bee + predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") + print(predictions) @@ -36,91 +36,45 @@ Predict on a csv file .. code-block:: python - from flash.core.data.utils import download_data - from flash.tabular import TabularClassifier + from flash.core.data.utils import download_data + from flash.tabular import TabularClassifier - # 1. Download the data - download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/') + # 1. Download the data + download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/") - # 2. Load the model from a checkpoint - model = TabularClassifier.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt" - ) + # 2. Load the model from a checkpoint + model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt") - # 3. Generate predictions from a csv file! Who would survive? - predictions = model.predict("data/titanic/titanic.csv") - print(predictions) + # 3. Generate predictions from a csv file! Who would survive? + predictions = model.predict("data/titanic/titanic.csv") + print(predictions) Serializing predictions ======================= To change how predictions are serialized you can attach a :class:`~flash.core.data.process.Serializer` to your -:class:`~flash.Task`. For example, you can choose to serialize outputs as probabilities (for more options see the API +:class:`~flash.core.model.Task`. For example, you can choose to serialize outputs as probabilities (for more options see the API reference below). .. code-block:: python - from flash.core.classification import Probabilities - from flash.core.data.utils import download_data - from flash.image import ImageClassifier + from flash.core.classification import Probabilities + from flash.core.data.utils import download_data + from flash.image import ImageClassifier - # 1. Download the data set - download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') + # 1. Download the data set + download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") - # 2. Load the model from a checkpoint - model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") + # 2. Load the model from a checkpoint + model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt") - # 3. Attach the Serializer - model.serializer = Probabilities() + # 3. Attach the Serializer + model.serializer = Probabilities() - # 4. Predict whether the image contains an ant or a bee - predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") - print(predictions) - # out: [[0.5926494598388672, 0.40735048055648804]] - - ------- - - -****************************************** -Classification serializers - API reference -****************************************** - -.. _logits: - -Logits ---------------- - -.. autoclass:: flash.core.classification.Logits - :members: - :exclude-members: serialize - -.. _probabilities: - -Probabilities ------------------------ - -.. autoclass:: flash.core.classification.Probabilities - :members: - :exclude-members: serialize - -.. _classes: - -Classes ------------------------ - -.. autoclass:: flash.core.classification.Classes - :members: - :exclude-members: serialize - -.. _labels: - -Labels ------------------------ - -.. autoclass:: flash.core.classification.Labels - :members: - :exclude-members: serialize + # 4. Predict whether the image contains an ant or a bee + predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg") + print(predictions) + # out: [[0.5926494598388672, 0.40735048055648804]] diff --git a/docs/source/general/registry.rst b/docs/source/general/registry.rst index 62ae14c67f..c3d7a96806 100644 --- a/docs/source/general/registry.rst +++ b/docs/source/general/registry.rst @@ -62,6 +62,7 @@ Your custom functions can be registered within a :class:`~flash.core.registry.Fl backbone, num_features = None, None return backbone, num_features + # HINT 1: Use `from functools import partial` if you want to store some arguments. MyImageClassifier.backbones(fn=partial(fn, backbone="my_backbone"), name="username/partial_backbone") @@ -98,9 +99,10 @@ Flash provides populated registries containing lots of available backbones. Example:: - from flash.image.backbones import IMAGE_CLASSIFIER_BACKBONES, OBJ_DETECTION_BACKBONES + from flash.image.backbones import OBJ_DETECTION_BACKBONES + from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES - print(IMAGE_CLASSIFIER_BACKBONES.available_models()) + print(IMAGE_CLASSIFIER_BACKBONES.available_keys()) """ out: ['adv_inception_v3', 'cspdarknet53', 'cspdarknet53_iabn', 430+.., 'xception71'] """ diff --git a/docs/source/general/serve.rst b/docs/source/general/serve.rst index bb2ba5728d..4e09ff6059 100644 --- a/docs/source/general/serve.rst +++ b/docs/source/general/serve.rst @@ -32,7 +32,7 @@ Here are common terms you need to be familiar with: - The :class:`~flash.core.serve.Composition` defines the computations / endpoints to create & run * - :func:`~flash.core.serve.decorators.expose` - The :func:`~flash.core.serve.decorators.expose` function is a python decorator used to - augment the :class:`~flash.core.serve.ModelComponent` inference function with de-serialization, serialization. + augment the :class:`~flash.core.serve.ModelComponent` inference function with de-serialization, serialization. ******* @@ -73,7 +73,7 @@ First, we need make the following imports: from flash.core.serve.types import Image, Label -.. image:: ../_static/images/data_serving_flow.png +.. image:: https://pl-flash-data.s3.amazonaws.com/assets/serve/data_serving_flow.png :width: 100% :alt: Data Serving Flow @@ -175,14 +175,14 @@ Just run: And you should see this in your terminal -.. image:: ../_static/images/inference_server.png +.. image:: https://pl-flash-data.s3.amazonaws.com/assets/serve/inference_server.png :width: 100% :alt: Data Serving Flow You should also see an Swagger UI already built for you at ``http://127.0.0.1:8000/docs`` -.. image:: ../_static/images/swagger_ui.png +.. image:: https://pl-flash-data.s3.amazonaws.com/assets/serve/swagger_ui.png :width: 100% :alt: Data Serving Flow diff --git a/docs/source/governance.rst b/docs/source/governance.rst new file mode 100644 index 0000000000..073c368466 --- /dev/null +++ b/docs/source/governance.rst @@ -0,0 +1,21 @@ +.. _governance: + +Flash Governance | Persons of interest +====================================== + +Leads +----- +- William Falcon (`williamFalcon `_) +- Thomas Chaton (`tchaton `_) +- Ethan Harris (`ethanwharris `_) + +Core Maintainers +---------------- +- Jirka Borovec (`Borda `_) +- Kaushik Bokka (`kaushikb11 `_) +- Justus Schock (`justusschock `_) +- Carlos Mocholí (`carmocca `_) +- Sean Narenthiran (`SeanNaren `_) +- Akihiro Nitta (`akihironitta `_) +- Aniket Maurya (`aniketmaurya `_) +- Ananya Harsh Jha (`ananyahjha93 `_) diff --git a/docs/source/index.rst b/docs/source/index.rst index 92fba5c46a..91ea1a09e5 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -24,6 +24,10 @@ Lightning Flash general/finetuning general/predictions general/jit + general/data + general/registry + general/flash_zero + general/serve .. toctree:: :maxdepth: 1 @@ -33,10 +37,19 @@ Lightning Flash reference/image_classification_multi_label reference/image_embedder reference/object_detection + reference/keypoint_detection + reference/instance_segmentation reference/semantic_segmentation reference/style_transfer reference/video_classification +.. toctree:: + :maxdepth: 1 + :caption: Audio + + reference/audio_classification + reference/speech_recognition + .. toctree:: :maxdepth: 1 :caption: Tabular @@ -52,26 +65,42 @@ Lightning Flash reference/summarization reference/translation +.. toctree:: + :maxdepth: 1 + :caption: Point Cloud + + reference/pointcloud_segmentation + reference/pointcloud_object_detection + +.. toctree:: + :maxdepth: 1 + :caption: Graph + + reference/graph_classification + .. toctree:: :maxdepth: 1 :caption: Integrations + integrations/providers integrations/fiftyone + integrations/icevision .. toctree:: :maxdepth: 1 :caption: API Reference - general/data - general/callback - general/registry - general/serve - code/core - code/data - code/image - code/tabular - code/text - code/video + api/flash + api/core + api/data + api/serve + api/image + api/audio + api/pointcloud + api/tabular + api/text + api/video + api/graph .. toctree:: :maxdepth: 1 @@ -86,6 +115,14 @@ Lightning Flash template/tests template/docs +.. toctree:: + :maxdepth: 1 + :caption: Community + + governance + generated/CONTRIBUTING.md + generated/CHANGELOG.md + .. toctree:: :hidden: diff --git a/docs/source/installation.md b/docs/source/installation.md index 0b44b8ddd0..d306090c11 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -12,7 +12,7 @@ Optionally, you can install Flash with extra packages for each domain or all dom ```bash pip install 'lightning-flash[image]' pip install 'lightning-flash[tabular]' -pip install 'lightnign-flash[text]' +pip install 'lightning-flash[text]' pip install 'lightning-flash[video]' # image + video diff --git a/docs/source/integrations/fiftyone.rst b/docs/source/integrations/fiftyone.rst index 25aa342727..8592fad47b 100644 --- a/docs/source/integrations/fiftyone.rst +++ b/docs/source/integrations/fiftyone.rst @@ -1,10 +1,12 @@ +.. _fiftyone: + ######## FiftyOne ######## We have collaborated with the team at -`Voxel51 `_ to integrate their tool, -`FiftyOne `_, into Lightning Flash. +`Voxel51 `__ to integrate their tool, +`FiftyOne `__, into Lightning Flash. FiftyOne is an open-source tool for building high-quality datasets and computer vision models. The FiftyOne API and App enable you to @@ -114,50 +116,3 @@ in only a few lines of code. .. image:: https://pl-flash-data.s3.amazonaws.com/assets/fiftyone/embeddings.png :alt: embeddings_example :align: center - ------- - -************* -API reference -************* - -.. _from_fiftyone: - -DataModule.from_fiftyone ------------------------- - -.. automethod:: flash.core.data.data_module.DataModule.from_fiftyone - :noindex: - -.. _fiftyone_labels: - -FiftyOneLabels --------------- - -.. autoclass:: flash.core.classification.FiftyOneLabels - :members: - -.. _fiftyone_segmentation_labels: - -FiftyOneSegmentationLabels --------------------------- - -.. autoclass:: flash.image.segmentation.serialization.FiftyOneSegmentationLabels - :members: - :noindex: - -.. _fiftyone_detection_labels: - -FiftyOneDetectionLabels ------------------------ - -.. autoclass:: flash.image.detection.serialization.FiftyOneDetectionLabels - :members: - - -.. _fiftyone_visualize: - -visualize ---------- - -.. autofunction:: flash.core.integrations.fiftyone.visualize diff --git a/docs/source/integrations/icevision.rst b/docs/source/integrations/icevision.rst new file mode 100644 index 0000000000..ff21565a4e --- /dev/null +++ b/docs/source/integrations/icevision.rst @@ -0,0 +1,44 @@ +.. _ice_vision: + +######### +IceVision +######### + +IceVision from airctic is an awesome computer vision framework which offers a curated collection of hundreds of high-quality pre-trained models for: object detection, keypoint detection, and instance segmentation. +In Flash, we've integrated the IceVision framework to provide: data loading, augmentation, backbones, and heads. +We use IceVision components in our: :ref:`object detection `, :ref:`instance segmentation `, and :ref:`keypoint detection ` tasks. +Take a look at `their documentation `_ and star `IceVision on GitHub `_ to spread the open source love! + +IceData +_______ + +The `IceData library `_ is a community driven dataset hub for IceVision. +All of the datasets in IceData can be used out of the box with flash using our ``.from_folders`` methods and the ``parser`` argument. +Take a look at our :ref:`keypoint_detection` page for an example. + +Albumentations with IceVision and Flash +_______________________________________ + +IceVision provides two utilities for using the `albumentations library `_ with their models: +- the ``Adapter`` helper class for adapting an any albumentations transform to work with IceVision records, +- the ``aug_tfms`` utility function that returns a standard augmentation recipe to get the most out of your model. + +In Flash, we use the ``aug_tfms`` as default transforms for the: :ref:`object detection `, :ref:`instance segmentation `, and :ref:`keypoint detection ` tasks. +You can also provide custom transforms from albumentations using the :class:`~flash.core.integrations.icevision.transforms.IceVisionTransformAdapter` (which relies on the IceVision ``Adapter`` underneath). +Here's an example: + +.. code-block:: python + + import albumentations as A + + from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter + from flash.image import ObjectDetectionData + + train_transform = { + "pre_tensor_transform": IceVisionTransformAdapter([A.HorizontalFlip(), A.Normalize()]), + } + + datamodule = ObjectDetectionData.from_coco( + ..., + train_transform=train_transform, + ) diff --git a/docs/source/integrations/providers.rst b/docs/source/integrations/providers.rst new file mode 100644 index 0000000000..7254acd6cf --- /dev/null +++ b/docs/source/integrations/providers.rst @@ -0,0 +1,15 @@ +.. _providers: + +######### +Providers +######### + +Flash is a framework integrator. +We rely on many open source frameworks for our tasks, visualizations and backbones. +Here's a list of some of the providers we use for backbones and heads within Flash (check them out and star their repos to spread the open source love!): + +.. include:: generated/providers.rst + +You can also read our guides for some of our larger integrations: + +- :ref:`fiftyone` diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index f07d739488..85cf5b6f53 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -98,11 +98,13 @@ Here's an example of inference: model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt") # 2. Perform inference from list of sequences - predictions = model.predict([ - "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", - "The worst movie in the history of cinema.", - "This guy has done a great job with this movie!", - ]) + predictions = model.predict( + [ + "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", + "The worst movie in the history of cinema.", + "This guy has done a great job with this movie!", + ] + ) print(predictions) We get the following output: diff --git a/docs/source/reference/audio_classification.rst b/docs/source/reference/audio_classification.rst new file mode 100644 index 0000000000..97c8df79b3 --- /dev/null +++ b/docs/source/reference/audio_classification.rst @@ -0,0 +1,92 @@ + +.. _audio_classification: + +#################### +Audio Classification +#################### + +******** +The Task +******** + +The task of identifying what is in an audio file is called audio classification. +Typically, Audio Classification is used to identify audio files containing sounds or words. +The task predicts which ‘class’ the sound or words most likely belongs to with a degree of certainty. +A class is a label that describes the sounds in an audio file, such as ‘children_playing’, ‘jackhammer’, ‘siren’ etc. + +------ + +******* +Example +******* + +Let's look at the task of predicting whether audio file contains sounds of an airconditioner, carhorn, childrenplaying, dogbark, drilling, engingeidling, gunshot, jackhammer, siren, or street_music using the UrbanSound8k spectrogram images dataset. +The dataset contains ``train``, ``val`` and ``test`` folders, and then each folder contains a **airconditioner** folder, with spectrograms generated from air-conditioner sounds, **siren** folder with spectrograms generated from siren sounds and the same goes for the other classes. + +.. code-block:: + + urban8k_images + ├── train + │ ├── air_conditioner + │ ├── car_horn + │ ├── children_playing + │ ├── dog_bark + │ ├── drilling + │ ├── engine_idling + │ ├── gun_shot + │ ├── jackhammer + │ ├── siren + │ └── street_music + ├── test + │ ├── air_conditioner + │ ├── car_horn + │ ├── children_playing + │ ├── dog_bark + │ ├── drilling + │ ├── engine_idling + │ ├── gun_shot + │ ├── jackhammer + │ ├── siren + │ └── street_music + └── val + ├── air_conditioner + ├── car_horn + ├── children_playing + ├── dog_bark + ├── drilling + ├── engine_idling + ├── gun_shot + ├── jackhammer + ├── siren + └── street_music + + ... + +Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.audio.classification.data.AudioClassificationData`. +We select a pre-trained backbone to use for our :class:`~flash.image.classification.model.ImageClassifier` and fine-tune on the UrbanSound8k spectrogram images data. +We then use the trained :class:`~flash.image.classification.model.ImageClassifier` for inference. +Finally, we save the model. +Here's the full example: + +.. literalinclude:: ../../../flash_examples/audio_classification.py + :language: python + :lines: 14- + +------ + +********** +Flash Zero +********** + +The audio classifier can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash audio_classification + +To view configuration options and options for running the audio classifier with your own data, use: + +.. code-block:: bash + + flash audio_classification --help diff --git a/docs/source/reference/graph_classification.rst b/docs/source/reference/graph_classification.rst new file mode 100644 index 0000000000..84cc8d12d4 --- /dev/null +++ b/docs/source/reference/graph_classification.rst @@ -0,0 +1,52 @@ +.. _graph_classification: + +#################### +Graph Classification +#################### + +******** +The Task +******** +This task consist on classifying graphs. +The task predicts which ‘class’ the graph belongs to. +A class is a label that indicates the kind of graph. +For example, a label may indicate whether one molecule interacts with another. + +The :class:`~flash.graph.classification.model.GraphClassifier` and :class:`~flash.graph.classification.data.GraphClassificationData` classes internally rely on `pytorch-geometric `_. + +------ + +******* +Example +******* + +Let's look at the task of classifying graphs from the KKI data set from `TU Dortmund University `_. + +Once we've created the `TUDataset `_, we create the :class:`~flash.graph.classification.data.GraphClassificationData`. +We then create our :class:`~flash.graph.classification.model.GraphClassifier` and train on the KKI data. +Next, we use the trained :class:`~flash.graph.classification.model.GraphClassifier` for inference. +Finally, we save the model. +Here's the full example: + +.. literalinclude:: ../../../flash_examples/graph_classification.py + :language: python + :lines: 14- + +------ + +********** +Flash Zero +********** + +The graph classifier can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash graph_classification + +To view configuration options and options for running the graph classifier with your own data, use: + +.. code-block:: bash + + flash graph_classification --help diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index 484abbc142..4116128f2a 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -57,12 +57,82 @@ Here's the full example: ------ +********************** +Custom Transformations +********************** + +Flash automatically applies some default image transformations and augmentations, but you may wish to customize these for your own use case. +The base :class:`~flash.core.data.process.Preprocess` defines 7 hooks for different stages in the data loading pipeline. +To apply image augmentations you can directly import the ``default_transforms`` from ``flash.image.classification.transforms`` and then merge your custom image transformations with them using the :func:`~flash.core.data.transforms.merge_transforms` helper function. +Here's an example where we load the default transforms and merge with custom `torchvision` transformations. +We use the `post_tensor_transform` hook to apply the transformations after the image has been converted to a `torch.Tensor`. + + +.. testsetup:: transformations + + from flash.core.data.utils import download_data + + download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data") + +.. testcode:: transformations + + from torchvision import transforms as T + + import flash + from flash.core.data.data_source import DefaultDataKeys + from flash.core.data.transforms import ApplyToKeys, merge_transforms + from flash.image import ImageClassificationData, ImageClassifier + from flash.image.classification.transforms import default_transforms + + post_tensor_transform = ApplyToKeys( + DefaultDataKeys.INPUT, + T.Compose([T.RandomHorizontalFlip(), T.ColorJitter(), T.RandomAutocontrast(), T.RandomPerspective()]), + ) + + new_transforms = merge_transforms(default_transforms((64, 64)), {"post_tensor_transform": post_tensor_transform}) + + datamodule = ImageClassificationData.from_folders( + train_folder="data/hymenoptera_data/train/", val_folder="data/hymenoptera_data/val/", train_transform=new_transforms + ) + + model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes) + + trainer = flash.Trainer(max_epochs=1) + trainer.finetune(model, datamodule=datamodule, strategy="freeze") + + +.. testoutput:: transformations + :hide: + + ... + +------ + +********** +Flash Zero +********** + +The image classifier can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the hymenoptera example with: + +.. code-block:: bash + + flash image_classification + +To view configuration options and options for running the image classifier with your own data, use: + +.. code-block:: bash + + flash image_classification --help + +------ + ******* Serving ******* The :class:`~flash.image.classification.model.ImageClassifier` is servable. -This means you can call ``.serve`` to serve your :class:`~flash.Task`. +This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`. Here's an example: .. literalinclude:: ../../../flash_examples/serve/image_classification/inference_server.py diff --git a/docs/source/reference/image_classification_multi_label.rst b/docs/source/reference/image_classification_multi_label.rst index c570a1f186..77e447c705 100644 --- a/docs/source/reference/image_classification_multi_label.rst +++ b/docs/source/reference/image_classification_multi_label.rst @@ -49,6 +49,27 @@ Here's the full example: :language: python :lines: 14- + +------ + +********** +Flash Zero +********** + +The multi-label image classifier can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the movie posters example with: + +.. code-block:: bash + + flash image_classification from_movie_posters + +To view configuration options and options for running the image classifier with your own data, use: + +.. code-block:: bash + + flash image_classification --help + + ------ ******* diff --git a/docs/source/reference/instance_segmentation.rst b/docs/source/reference/instance_segmentation.rst new file mode 100644 index 0000000000..db864ad2bc --- /dev/null +++ b/docs/source/reference/instance_segmentation.rst @@ -0,0 +1,50 @@ + +.. _instance_segmentation: + +##################### +Instance Segmentation +##################### + +******** +The Task +******** + +Instance segmentation is the task of segmenting objects images and determining their associated classes. + +The :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` and :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` classes internally rely on `IceVision `_. + +------ + +******* +Example +******* + +Let's look at instance segmentation with `The Oxford-IIIT Pet Dataset `_ from `IceData `_. +Once we've downloaded the data, we can create the :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData`. +We select a ``mask_rcnn`` with a ``resnet18_fpn`` backbone to use for our :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` and fine-tune on the pets data. +We then use the trained :class:`~flash.image.instance_segmentation.model.InstanceSegmentation` for inference. +Finally, we save the model. +Here's the full example: + +.. literalinclude:: ../../../flash_examples/instance_segmentation.py + :language: python + :lines: 14- + +------ + +********** +Flash Zero +********** + +The instance segmentation task can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash instance_segmentation + +To view configuration options and options for running the instance segmentation task with your own data, use: + +.. code-block:: bash + + flash instance_segmentation --help diff --git a/docs/source/reference/keypoint_detection.rst b/docs/source/reference/keypoint_detection.rst new file mode 100644 index 0000000000..2cc0fbef40 --- /dev/null +++ b/docs/source/reference/keypoint_detection.rst @@ -0,0 +1,50 @@ + +.. _keypoint_detection: + +################## +Keypoint Detection +################## + +******** +The Task +******** + +Keypoint detection is the task of identifying keypoints in images and their associated classes. + +The :class:`~flash.image.keypoint_detection.model.KeypointDetector` and :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` classes internally rely on `IceVision `_. + +------ + +******* +Example +******* + +Let's look at keypoint detection with `BIWI Sample Keypoints (center of face) `_ from `IceData `_. +Once we've downloaded the data, we can create the :class:`~flash.image.keypoint_detection.data.KeypointDetectionData`. +We select a ``keypoint_rcnn`` with a ``resnet18_fpn`` backbone to use for our :class:`~flash.image.keypoint_detection.model.KeypointDetector` and fine-tune on the BIWI data. +We then use the trained :class:`~flash.image.keypoint_detection.model.KeypointDetector` for inference. +Finally, we save the model. +Here's the full example: + +.. literalinclude:: ../../../flash_examples/keypoint_detection.py + :language: python + :lines: 14- + +------ + +********** +Flash Zero +********** + +The keypoint detector can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash keypoint_detection + +To view configuration options and options for running the keypoint detector with your own data, use: + +.. code-block:: bash + + flash keypoint_detection --help diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index bf82bec153..0bf34c07c3 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -11,6 +11,8 @@ The Task Object detection is the task of identifying objects in images and their associated classes and bounding boxes. +The :class:`~flash.image.detection.model.ObjectDetector` and :class:`~flash.image.detection.data.ObjectDetectionData` classes internally rely on `IceVision `_. + ------ ******* @@ -47,3 +49,22 @@ Here's the full example: .. literalinclude:: ../../../flash_examples/object_detection.py :language: python :lines: 14- + +------ + +********** +Flash Zero +********** + +The object detector can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash object_detection + +To view configuration options and options for running the object detector with your own data, use: + +.. code-block:: bash + + flash object_detection --help diff --git a/docs/source/reference/pointcloud_object_detection.rst b/docs/source/reference/pointcloud_object_detection.rst new file mode 100644 index 0000000000..1be71919f3 --- /dev/null +++ b/docs/source/reference/pointcloud_object_detection.rst @@ -0,0 +1,99 @@ + +.. _pointcloud_object_detection: + +############################ +Point Cloud Object Detection +############################ + +******** +The Task +******** + +A Point Cloud is a set of data points in space, usually describes by ``x``, ``y`` and ``z`` coordinates. + +PointCloud Object Detection is the task of identifying 3D objects in point clouds and their associated classes and 3D bounding boxes. + +The current integration builds on top `Open3D-ML `_. + +------ + +******* +Example +******* + +Let's look at an example using a data set generated from the `KITTI Vision Benchmark `_. +The data are a tiny subset of the original dataset and contains sequences of point clouds. + +The data contains: + * one folder for scans + * one folder for scan calibrations + * one folder for labels + * a meta.yaml file describing the classes and their official associated color map. + +Here's the structure: + +.. code-block:: + + data + ├── meta.yaml + ├── train + │ ├── scans + | | ├── 00000.bin + | | ├── 00001.bin + | | ... + │ ├── calibs + | | ├── 00000.txt + | | ├── 00001.txt + | | ... + │ ├── labels + | | ├── 00000.txt + | | ├── 00001.txt + │ ... + ├── val + │ ... + ├── predict + ├── scans + | ├── 00000.bin + | ├── 00001.bin + | + ├── calibs + | ├── 00000.txt + | ├── 00001.txt + ├── meta.yaml + + + +Learn more: http://www.semantic-kitti.org/dataset.html + + +Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.image.detection.data.PointCloudObjectDetectorData`. +We select a pre-trained ``randlanet_semantic_kitti`` backbone for our :class:`~flash.image.detection.model.PointCloudObjectDetector` task. +We then use the trained :class:`~flash.image.detection.model.PointCloudObjectDetector` for inference. +Finally, we save the model. +Here's the full example: + +.. literalinclude:: ../../../flash_examples/pointcloud_detection.py + :language: python + :lines: 14- + +.. image:: https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/visualizer_BoundingBoxes.png + :width: 100% + +------ + +********** +Flash Zero +********** + +The point cloud object detector can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash pointcloud_detection + +To view configuration options and options for running the point cloud object detector with your own data, use: + +.. code-block:: bash + + flash pointcloud_detection --help diff --git a/docs/source/reference/pointcloud_segmentation.rst b/docs/source/reference/pointcloud_segmentation.rst new file mode 100644 index 0000000000..1777313521 --- /dev/null +++ b/docs/source/reference/pointcloud_segmentation.rst @@ -0,0 +1,90 @@ + +.. _pointcloud_segmentation: + +######################## +Point Cloud Segmentation +######################## + +******** +The Task +******** + +A Point Cloud is a set of data points in space, usually describes by ``x``, ``y`` and ``z`` coordinates. + +PointCloud Segmentation is the task of performing classification at a point-level, meaning each point will associated to a given class. +The current integration builds on top `Open3D-ML `_. + +------ + +******* +Example +******* + +Let's look at an example using a data set generated from the `KITTI Vision Benchmark `_. +The data are a tiny subset of the original dataset and contains sequences of point clouds. +The data contains multiple folder, one for each sequence and a meta.yaml file describing the classes and their official associated color map. +A sequence should contain one folder for scans and one folder for labels, plus a ``pose.txt`` to re-align the sequence if required. +Here's the structure: + +.. code-block:: + + data + ├── meta.yaml + ├── 00 + │ ├── scans + | | ├── 00000.bin + | | ├── 00001.bin + | | ... + │ ├── labels + | | ├── 00000.label + | | ├── 00001.label + | | ... + | ├── pose.txt + │ ... + | + └── XX + ├── scans + | ├── 00000.bin + | ├── 00001.bin + | ... + ├── labels + | ├── 00000.label + | ├── 00001.label + | ... + ├── pose.txt + + +Learn more: http://www.semantic-kitti.org/dataset.html + + +Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the ``PointCloudSegmentationData``. +We select a pre-trained ``randlanet_semantic_kitti`` backbone for our ``PointCloudSegmentation`` task. +We then use the trained ``PointCloudSegmentation`` for inference. +Finally, we save the model. +Here's the full example: + +.. literalinclude:: ../../../flash_examples/pointcloud_segmentation.py + :language: python + :lines: 14- + +.. image:: https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/getting_started_ml_visualizer.gif + :width: 100% + +------ + +********** +Flash Zero +********** + +The point cloud segmentation task can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash pointcloud_segmentation + +To view configuration options and options for running the point cloud segmentation task with your own data, use: + +.. code-block:: bash + + flash pointcloud_segmentation --help diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst index b8deabd800..92cbe67314 100644 --- a/docs/source/reference/semantic_segmentation.rst +++ b/docs/source/reference/semantic_segmentation.rst @@ -36,7 +36,7 @@ Here's the structure: Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.image.segmentation.data.SemanticSegmentationData`. We select a pre-trained ``mobilenet_v3_large`` backbone with an ``fpn`` head to use for our :class:`~flash.image.segmentation.model.SemanticSegmentation` task and fine-tune on the CARLA data. -We then use the trained :class:`~flash.image.segmentation.model.SemanticSegmentation` for inference. +We then use the trained :class:`~flash.image.segmentation.model.SemanticSegmentation` for inference. You can check the available pretrained weights for the backbones like this `SemanticSegmentation.available_pretrained_weights("resnet18")`. Finally, we save the model. Here's the full example: @@ -44,6 +44,27 @@ Here's the full example: :language: python :lines: 14- + +------ + +********** +Flash Zero +********** + +The semantic segmentation task can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash semantic_segmentation + +To view configuration options and options for running the semantic segmentation task with your own data, use: + +.. code-block:: bash + + flash semantic_segmentation --help + + ------ ******* @@ -51,7 +72,7 @@ Serving ******* The :class:`~flash.image.segmentation.model.SemanticSegmentation` task is servable. -This means you can call ``.serve`` to serve your :class:`~flash.Task`. +This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`. Here's an example: .. literalinclude:: ../../../flash_examples/serve/semantic_segmentation/inference_server.py diff --git a/docs/source/reference/speech_recognition.rst b/docs/source/reference/speech_recognition.rst new file mode 100644 index 0000000000..2b6918078c --- /dev/null +++ b/docs/source/reference/speech_recognition.rst @@ -0,0 +1,87 @@ +.. _speech_recognition: + +################## +Speech Recognition +################## + +******** +The Task +******** + +Speech recognition is the task of classifying audio into a text transcription. We rely on `Wav2Vec `_ as our backbone, fine-tuned on labeled transcriptions for speech to text. +Wav2Vec is pre-trained on thousand of hours of unlabeled audio, providing a strong baseline when fine-tuning to downstream tasks such as Speech Recognition. + +----- + +******* +Example +******* + +Let's fine-tune the model onto our own labeled audio transcription data: + +Here's the structure our CSV file: + +.. code-block:: + + file,text + "/path/to/file_1.wav","what was said in file 1." + "/path/to/file_2.wav","what was said in file 2." + "/path/to/file_3.wav","what was said in file 3." + ... + +Alternatively, here is the structure of our JSON file: + +.. code-block:: + + {"file": "/path/to/file_1.wav", "text": "what was said in file 1."} + {"file": "/path/to/file_2.wav", "text": "what was said in file 2."} + {"file": "/path/to/file_3.wav", "text": "what was said in file 3."} + +Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData`. +We select a pre-trained Wav2Vec backbone to use for our :class:`~flash.audio.speech_recognition.model.SpeechRecognition` and finetune on a subset of the `TIMIT corpus `__. +The backbone can be any Wav2Vec model from `HuggingFace transformers `__. +Next, we use the trained :class:`~flash.audio.speech_recognition.model.SpeechRecognition` for inference and save the model. +Here's the full example: + +.. literalinclude:: ../../../flash_examples/speech_recognition.py + :language: python + :lines: 14- + +------ + +********** +Flash Zero +********** + +The speech recognition task can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash speech_recognition + +To view configuration options and options for running the speech recognition task with your own data, use: + +.. code-block:: bash + + flash speech_recognition --help + +------ + +******* +Serving +******* + +The :class:`~flash.audio.speech_recognition.model.SpeechRecognition` is servable. +This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`. +Here's an example: + +.. literalinclude:: ../../../flash_examples/serve/speech_recognition/inference_server.py + :language: python + :lines: 14- + +You can now perform inference from your client like this: + +.. literalinclude:: ../../../flash_examples/serve/speech_recognition/client.py + :language: python + :lines: 14- diff --git a/docs/source/reference/style_transfer.rst b/docs/source/reference/style_transfer.rst index 175cf21426..4b19c940ef 100644 --- a/docs/source/reference/style_transfer.rst +++ b/docs/source/reference/style_transfer.rst @@ -12,7 +12,7 @@ The Task The Neural Style Transfer Task is an optimization method which extract the style from an image and apply it another image while preserving its content. The goal is that the output image looks like the content image, but “painted” in the style of the style reference image. -.. image:: https://raw.githubusercontent.com/pystiche/pystiche/master/docs/source/graphics/banner/banner.jpg +.. image:: https://raw.githubusercontent.com/pystiche/pystiche/main/docs/source/graphics/banner/banner.jpg :alt: style_transfer_example The :class:`~flash.image.style_transfer.model.StyleTransfer` and :class:`~flash.image.style_transfer.data.StyleTransferData` classes internally rely on `pystiche `_. @@ -33,3 +33,22 @@ Here's the full example: .. literalinclude:: ../../../flash_examples/style_transfer.py :language: python :lines: 14- + +------ + +********** +Flash Zero +********** + +The style transfer task can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash style_transfer + +To view configuration options and options for running the style transfer task with your own data, use: + +.. code-block:: bash + + flash style_transfer --help diff --git a/docs/source/reference/summarization.rst b/docs/source/reference/summarization.rst index 48dfa58134..6010324cb1 100644 --- a/docs/source/reference/summarization.rst +++ b/docs/source/reference/summarization.rst @@ -49,12 +49,31 @@ Here's the full example: ------ +********** +Flash Zero +********** + +The summarization task can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash summarization + +To view configuration options and options for running the summarization task with your own data, use: + +.. code-block:: bash + + flash summarization --help + +------ + ******* Serving ******* The :class:`~flash.text.seq2seq.summarization.model.SummarizationTask` is servable. -This means you can call ``.serve`` to serve your :class:`~flash.Task`. +This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`. Here's an example: .. literalinclude:: ../../../flash_examples/serve/summarization/inference_server.py @@ -66,3 +85,21 @@ You can now perform inference from your client like this: .. literalinclude:: ../../../flash_examples/serve/summarization/client.py :language: python :lines: 14- + +------ + +********************************************** +Accelerate Training & Inference with Torch ORT +********************************************** + +`Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``SummarizationTask`` once installed. See installation instructions `here `__. + +.. note:: + + Not all Transformer models are supported. See `this table `__ for supported models + branches containing fixes for certain models. + +.. code-block:: python + + ... + + model = SummarizationTask(backbone="t5-large", num_classes=datamodule.num_classes, enable_ort=True) diff --git a/docs/source/reference/tabular_classification.rst b/docs/source/reference/tabular_classification.rst index ab4d4b85f2..48ce18a872 100644 --- a/docs/source/reference/tabular_classification.rst +++ b/docs/source/reference/tabular_classification.rst @@ -48,12 +48,31 @@ Here's the full example: ------ +********** +Flash Zero +********** + +The tabular classifier can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash tabular_classifier + +To view configuration options and options for running the tabular classifier with your own data, use: + +.. code-block:: bash + + flash tabular_classifier --help + +------ + ******* Serving ******* The :class:`~flash.tabular.classification.model.TabularClassifier` is servable. -This means you can call ``.serve`` to serve your :class:`~flash.Task`. +This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`. Here's an example: .. literalinclude:: ../../../flash_examples/serve/tabular_classification/inference_server.py diff --git a/docs/source/reference/text_classification.rst b/docs/source/reference/text_classification.rst index a27d04412d..989ce2e387 100644 --- a/docs/source/reference/text_classification.rst +++ b/docs/source/reference/text_classification.rst @@ -49,12 +49,31 @@ Here's the full example: ------ +********** +Flash Zero +********** + +The text classifier can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash text_classification + +To view configuration options and options for running the text classifier with your own data, use: + +.. code-block:: bash + + flash text_classification --help + +------ + ******* Serving ******* The :class:`~flash.text.classification.model.TextClassifier` is servable. -This means you can call ``.serve`` to serve your :class:`~flash.Task`. +This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`. Here's an example: .. literalinclude:: ../../../flash_examples/serve/text_classification/inference_server.py @@ -66,3 +85,21 @@ You can now perform inference from your client like this: .. literalinclude:: ../../../flash_examples/serve/text_classification/client.py :language: python :lines: 14- + +------ + +********************************************** +Accelerate Training & Inference with Torch ORT +********************************************** + +`Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``TextClassifier`` once installed. See installation instructions `here `__. + +.. note:: + + Not all Transformer models are supported. See `this table `__ for supported models + branches containing fixes for certain models. + +.. code-block:: python + + ... + + model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True) diff --git a/docs/source/reference/text_classification_multi_label.rst b/docs/source/reference/text_classification_multi_label.rst index 6b65ae5a6f..54929122ab 100644 --- a/docs/source/reference/text_classification_multi_label.rst +++ b/docs/source/reference/text_classification_multi_label.rst @@ -47,6 +47,25 @@ Here's the full example: ------ +********** +Flash Zero +********** + +The multi-label text classifier can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash text_classification from_toxic + +To view configuration options and options for running the text classifier with your own data, use: + +.. code-block:: bash + + flash text_classification --help + +------ + ******* Serving ******* diff --git a/docs/source/reference/translation.rst b/docs/source/reference/translation.rst index 7fde16297d..cc7c21c517 100644 --- a/docs/source/reference/translation.rst +++ b/docs/source/reference/translation.rst @@ -49,12 +49,31 @@ Here's the full example: ------ +********** +Flash Zero +********** + +The translation task can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash translation + +To view configuration options and options for running the translation task with your own data, use: + +.. code-block:: bash + + flash translation --help + +------ + ******* Serving ******* The :class:`~flash.text.seq2seq.translation.model.TranslationTask` is servable. -This means you can call ``.serve`` to serve your :class:`~flash.Task`. +This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`. Here's an example: .. literalinclude:: ../../../flash_examples/serve/translation/inference_server.py @@ -66,3 +85,21 @@ You can now perform inference from your client like this: .. literalinclude:: ../../../flash_examples/serve/translation/client.py :language: python :lines: 14- + +------ + +********************************************** +Accelerate Training & Inference with Torch ORT +********************************************** + +`Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``TranslationTask`` once installed. See installation instructions `here `__. + +.. note:: + + Not all Transformer models are supported. See `this table `__ for supported models + branches containing fixes for certain models. + +.. code-block:: python + + ... + + model = TranslationTask(backbone="t5-large", num_classes=datamodule.num_classes, enable_ort=True) diff --git a/docs/source/reference/video_classification.rst b/docs/source/reference/video_classification.rst index 9fb40c9569..4a60280ad8 100644 --- a/docs/source/reference/video_classification.rst +++ b/docs/source/reference/video_classification.rst @@ -56,3 +56,22 @@ Here's the full example: .. literalinclude:: ../../../flash_examples/video_classification.py :language: python :lines: 14- + +------ + +********** +Flash Zero +********** + +The video classifier can be used directly from the command line with zero code using :ref:`flash_zero`. +You can run the above example with: + +.. code-block:: bash + + flash video_classification + +To view configuration options and options for running the video classifier with your own data, use: + +.. code-block:: bash + + flash video_classification --help diff --git a/docs/source/template/backbones.rst b/docs/source/template/backbones.rst index 82c629430f..bcbac896a2 100644 --- a/docs/source/template/backbones.rst +++ b/docs/source/template/backbones.rst @@ -34,11 +34,11 @@ Here's another example with a slightly more complex model: :language: python :pyobject: load_mlp_128_256 -Here's a more advanced example, which adds ``SimCLR`` to the ``IMAGE_CLASSIFIER_BACKBONES``, from `flash/image/backbones.py `_: +Here's a another example, which adds ``DINO`` pretrained model from PyTorch Hub to the ``IMAGE_CLASSIFIER_BACKBONES``, from `flash/image/classification/backbones/transformers.py `_: -.. literalinclude:: ../../../flash/image/backbones.py +.. literalinclude:: ../../../flash/image/classification/backbones/transformers.py :language: python - :pyobject: load_simclr_imagenet + :pyobject: dino_vitb16 ------ diff --git a/docs/source/template/tests.rst b/docs/source/template/tests.rst index 0c3dd9f228..33d85952fb 100644 --- a/docs/source/template/tests.rst +++ b/docs/source/template/tests.rst @@ -24,15 +24,11 @@ Here's how those lines look for our ``template.py`` examples: .. code-block:: python pytest.param( - "finetuning", - "template.py", - marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed") + "finetuning", "template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed") ), ... pytest.param( - "predict", - "template.py", - marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed") + "predict", "template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed") ), test_data.py diff --git a/flash/__about__.py b/flash/__about__.py index d66522a669..eab8629bc9 100644 --- a/flash/__about__.py +++ b/flash/__about__.py @@ -1,7 +1,7 @@ -__version__ = "0.4.1dev" +__version__ = "0.5.0dev" __author__ = "PyTorchLightning et al." __author_email__ = "name@pytorchlightning.ai" -__license__ = 'Apache-2.0' +__license__ = "Apache-2.0" __copyright__ = f"Copyright (c) 2020-2021, f{__author__}." __homepage__ = "https://github.com/PyTorchLightning/lightning-flash" __docs_url__ = "https://lightning-flash.readthedocs.io/en/stable/" diff --git a/flash/__init__.py b/flash/__init__.py index 7a13f9d20b..e8321350c9 100644 --- a/flash/__init__.py +++ b/flash/__init__.py @@ -33,6 +33,7 @@ if _IS_TESTING: from pytorch_lightning import seed_everything + seed_everything(42) __all__ = [ diff --git a/flash/__main__.py b/flash/__main__.py new file mode 100644 index 0000000000..fba73c4fac --- /dev/null +++ b/flash/__main__.py @@ -0,0 +1,71 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +import importlib +from unittest.mock import patch + +import click + + +@click.group(no_args_is_help=True) +def main(): + """The Lightning-Flash zero-code command line utility.""" + + +def register_command(command): + @main.command( + command.__name__, + context_settings=dict( + help_option_names=[], + ignore_unknown_options=True, + ), + ) + @click.argument("cli_args", nargs=-1, type=click.UNPROCESSED) + @functools.wraps(command) + def wrapper(cli_args): + with patch("sys.argv", [command.__name__] + list(cli_args)): + command() + + +tasks = [ + "flash.audio.classification", + "flash.audio.speech_recognition", + "flash.graph.classification", + "flash.image.classification", + "flash.image.detection", + "flash.image.instance_segmentation", + "flash.image.keypoint_detection", + "flash.image.segmentation", + "flash.image.style_transfer", + "flash.pointcloud.detection", + "flash.pointcloud.segmentation", + "flash.tabular.classification", + "flash.text.classification", + "flash.text.seq2seq.summarization", + "flash.text.seq2seq.translation", + "flash.video.classification", +] + +for task in tasks: + try: + task = importlib.import_module(f"{task}.cli") + + for command in task.__all__: + command = task.__dict__[command] + register_command(command) + except ImportError: + pass + +if __name__ == "__main__": + main() diff --git a/flash/assets/example.wav b/flash/assets/example.wav new file mode 100644 index 0000000000..8a1d66a36b Binary files /dev/null and b/flash/assets/example.wav differ diff --git a/flash/audio/__init__.py b/flash/audio/__init__.py new file mode 100644 index 0000000000..b90bc6d06e --- /dev/null +++ b/flash/audio/__init__.py @@ -0,0 +1,2 @@ +from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess # noqa: F401 +from flash.audio.speech_recognition import SpeechRecognition, SpeechRecognitionData # noqa: F401 diff --git a/flash/audio/classification/__init__.py b/flash/audio/classification/__init__.py new file mode 100644 index 0000000000..476a303d49 --- /dev/null +++ b/flash/audio/classification/__init__.py @@ -0,0 +1 @@ +from flash.audio.classification.data import AudioClassificationData, AudioClassificationPreprocess # noqa: F401 diff --git a/flash/audio/classification/cli.py b/flash/audio/classification/cli.py new file mode 100644 index 0000000000..c198a99239 --- /dev/null +++ b/flash/audio/classification/cli.py @@ -0,0 +1,55 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from flash.audio import AudioClassificationData +from flash.core.data.utils import download_data +from flash.core.utilities.flash_cli import FlashCLI +from flash.image import ImageClassifier + +__all__ = ["audio_classification"] + + +def from_urban8k( + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> AudioClassificationData: + """Downloads and loads the Urban 8k sounds images data set.""" + download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data") + return AudioClassificationData.from_folders( + train_folder="data/urban8k_images/train", + val_folder="data/urban8k_images/val", + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def audio_classification(): + """Classify audio spectrograms.""" + cli = FlashCLI( + ImageClassifier, + AudioClassificationData, + default_datamodule_builder=from_urban8k, + default_arguments={ + "trainer.max_epochs": 3, + }, + ) + + cli.trainer.save_checkpoint("audio_classification_model.pt") + + +if __name__ == "__main__": + audio_classification() diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py new file mode 100644 index 0000000000..ac0748e666 --- /dev/null +++ b/flash/audio/classification/data.py @@ -0,0 +1,112 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, Optional, Tuple + +import numpy as np + +from flash.audio.classification.transforms import default_transforms, train_default_transforms +from flash.core.data.data_source import ( + DefaultDataSources, + has_file_allowed_extension, + LoaderDataFrameDataSource, + PathsDataSource, +) +from flash.core.data.process import Deserializer, Preprocess +from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, requires_extras +from flash.image.classification.data import ImageClassificationData +from flash.image.data import ImageDeserializer + +if _TORCHVISION_AVAILABLE: + from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS + + +NP_EXTENSIONS = (".npy", ".npz") + + +def spectrogram_loader(filepath: str): + if has_file_allowed_extension(filepath, IMG_EXTENSIONS): + img = default_loader(filepath) + data = np.array(img) + else: + data = np.load(filepath) + return data + + +class AudioClassificationPathsDataSource(PathsDataSource): + @requires_extras("image") + def __init__(self): + super().__init__(loader=spectrogram_loader, extensions=IMG_EXTENSIONS + NP_EXTENSIONS) + + +class AudioClassificationDataFrameDataSource(LoaderDataFrameDataSource): + @requires_extras("image") + def __init__(self): + super().__init__(spectrogram_loader) + + +class AudioClassificationPreprocess(Preprocess): + @requires_extras(["audio", "image"]) + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + spectrogram_size: Tuple[int, int] = (128, 128), + time_mask_param: int = 80, + freq_mask_param: int = 80, + deserializer: Optional["Deserializer"] = None, + ): + self.spectrogram_size = spectrogram_size + self.time_mask_param = time_mask_param + self.freq_mask_param = freq_mask_param + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.FILES: AudioClassificationPathsDataSource(), + DefaultDataSources.FOLDERS: AudioClassificationPathsDataSource(), + "data_frame": AudioClassificationDataFrameDataSource(), + DefaultDataSources.CSV: AudioClassificationDataFrameDataSource(), + }, + deserializer=deserializer or ImageDeserializer(), + default_data_source=DefaultDataSources.FILES, + ) + + def get_state_dict(self) -> Dict[str, Any]: + return { + **self.transforms, + "spectrogram_size": self.spectrogram_size, + "time_mask_param": self.time_mask_param, + "freq_mask_param": self.freq_mask_param, + } + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) + + def default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.spectrogram_size) + + def train_default_transforms(self) -> Optional[Dict[str, Callable]]: + return train_default_transforms(self.spectrogram_size, self.time_mask_param, self.freq_mask_param) + + +class AudioClassificationData(ImageClassificationData): + """Data module for audio classification.""" + + preprocess_cls = AudioClassificationPreprocess diff --git a/flash/audio/classification/transforms.py b/flash/audio/classification/transforms.py new file mode 100644 index 0000000000..04599ffd17 --- /dev/null +++ b/flash/audio/classification/transforms.py @@ -0,0 +1,56 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Dict, Tuple + +import torch +from torch import nn +from torch.utils.data._utils.collate import default_collate + +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.transforms import ApplyToKeys, merge_transforms +from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE + +if _TORCHVISION_AVAILABLE: + import torchvision + from torchvision import transforms as T + +if _TORCHAUDIO_AVAILABLE: + from torchaudio import transforms as TAudio + + +def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]: + """The default transforms for audio classification for spectrograms: resize the spectrogram, convert the + spectrogram and target to a tensor, and collate the batch.""" + return { + "to_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)), + "collate": default_collate, + } + + +def train_default_transforms( + spectrogram_size: Tuple[int, int], time_mask_param: int, freq_mask_param: int +) -> Dict[str, Callable]: + """During training we apply the default transforms with additional ``TimeMasking`` and ``Frequency Masking``""" + transforms = { + "post_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)), + ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)), + ) + } + + return merge_transforms(default_transforms(spectrogram_size), transforms) diff --git a/flash/audio/speech_recognition/__init__.py b/flash/audio/speech_recognition/__init__.py new file mode 100644 index 0000000000..00f1b6fa0c --- /dev/null +++ b/flash/audio/speech_recognition/__init__.py @@ -0,0 +1,15 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.audio.speech_recognition.data import SpeechRecognitionData # noqa: F401 +from flash.audio.speech_recognition.model import SpeechRecognition # noqa: F401 diff --git a/flash/audio/speech_recognition/backbone.py b/flash/audio/speech_recognition/backbone.py new file mode 100644 index 0000000000..e583d7366a --- /dev/null +++ b/flash/audio/speech_recognition/backbone.py @@ -0,0 +1,32 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _AUDIO_AVAILABLE +from flash.core.utilities.providers import _FAIRSEQ, _HUGGINGFACE + +SPEECH_RECOGNITION_BACKBONES = FlashRegistry("backbones") + +if _AUDIO_AVAILABLE: + from transformers import Wav2Vec2ForCTC + + WAV2VEC_MODELS = ["facebook/wav2vec2-base-960h", "facebook/wav2vec2-large-960h-lv60"] + + for model_name in WAV2VEC_MODELS: + SPEECH_RECOGNITION_BACKBONES( + fn=partial(Wav2Vec2ForCTC.from_pretrained, model_name), + name=model_name, + providers=[_HUGGINGFACE, _FAIRSEQ], + ) diff --git a/flash/audio/speech_recognition/cli.py b/flash/audio/speech_recognition/cli.py new file mode 100644 index 0000000000..9bbdb48df8 --- /dev/null +++ b/flash/audio/speech_recognition/cli.py @@ -0,0 +1,59 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from flash.audio import SpeechRecognition, SpeechRecognitionData +from flash.core.data.utils import download_data +from flash.core.utilities.flash_cli import FlashCLI + +__all__ = ["speech_recognition"] + + +def from_timit( + val_split: float = 0.1, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> SpeechRecognitionData: + """Downloads and loads the timit data set.""" + download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data") + return SpeechRecognitionData.from_json( + input_fields="file", + target_fields="text", + train_file="data/timit/train.json", + test_file="data/timit/test.json", + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def speech_recognition(): + """Speech recognition.""" + cli = FlashCLI( + SpeechRecognition, + SpeechRecognitionData, + default_datamodule_builder=from_timit, + default_arguments={ + "trainer.max_epochs": 3, + }, + finetune=False, + ) + + cli.trainer.save_checkpoint("speech_recognition_model.pt") + + +if __name__ == "__main__": + speech_recognition() diff --git a/flash/audio/speech_recognition/collate.py b/flash/audio/speech_recognition/collate.py new file mode 100644 index 0000000000..9ee53a4686 --- /dev/null +++ b/flash/audio/speech_recognition/collate.py @@ -0,0 +1,101 @@ +# Copyright 2020 The PyTorch Lightning team and The HuggingFace Team. All rights reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import torch + +from flash.core.data.data_source import DefaultDataKeys +from flash.core.utilities.imports import _AUDIO_AVAILABLE + +if _AUDIO_AVAILABLE: + from transformers import Wav2Vec2Processor +else: + Wav2Vec2Processor = object + + +@dataclass +class DataCollatorCTCWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor (:class:`~transformers.Wav2Vec2Processor`) + The processor used for proccessing the data. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, + `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + max_length (:obj:`int`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + max_length_labels (:obj:`int`, `optional`): + Maximum length of the ``labels`` returned list and optionally padding length (see above). + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + """ + + processor: Wav2Vec2Processor + padding: Union[bool, str] = True + max_length: Optional[int] = None + max_length_labels: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + pad_to_multiple_of_labels: Optional[int] = None + + def __call__(self, samples: List[Dict[str, Any]], metadata: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + inputs = [sample[DefaultDataKeys.INPUT] for sample in samples] + sampling_rates = [sample["sampling_rate"] for sample in metadata] + + assert ( + len(set(sampling_rates)) == 1 + ), f"Make sure all inputs have the same sampling rate of {self.processor.feature_extractor.sampling_rate}." + + inputs = self.processor(inputs, sampling_rate=sampling_rates[0]).input_values + + # split inputs and labels since they have to be of different lengths and need + # different padding methods + input_features = [{"input_values": input} for input in inputs] + + batch = self.processor.pad( + input_features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors="pt", + ) + + labels = [sample.get(DefaultDataKeys.TARGET, None) for sample in samples] + # check to ensure labels exist to collate + if None not in labels: + with self.processor.as_target_processor(): + label_features = self.processor(labels).input_ids + label_features = [{"input_ids": feature} for feature in label_features] + labels_batch = self.processor.pad( + label_features, + padding=self.padding, + max_length=self.max_length_labels, + pad_to_multiple_of=self.pad_to_multiple_of_labels, + return_tensors="pt", + ) + + # replace padding with -100 to ignore loss correctly + batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) + + return batch diff --git a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py new file mode 100644 index 0000000000..029419b50b --- /dev/null +++ b/flash/audio/speech_recognition/data.py @@ -0,0 +1,220 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import base64 +import io +import os.path +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union + +import torch +from torch.utils.data import Dataset + +import flash +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import ( + DatasetDataSource, + DataSource, + DefaultDataKeys, + DefaultDataSources, + PathsDataSource, +) +from flash.core.data.process import Deserializer, Postprocess, Preprocess +from flash.core.data.properties import ProcessState +from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires_extras + +if _AUDIO_AVAILABLE: + import soundfile as sf + from datasets import Dataset as HFDataset + from datasets import load_dataset + from transformers import Wav2Vec2CTCTokenizer +else: + HFDataset = object + + +class SpeechRecognitionDeserializer(Deserializer): + def deserialize(self, sample: Any) -> Dict: + encoded_with_padding = (sample + "===").encode("ascii") + audio = base64.b64decode(encoded_with_padding) + buffer = io.BytesIO(audio) + data, sampling_rate = sf.read(buffer) + return { + DefaultDataKeys.INPUT: data, + DefaultDataKeys.METADATA: {"sampling_rate": sampling_rate}, + } + + @property + def example_input(self) -> str: + with (Path(flash.ASSETS_ROOT) / "example.wav").open("rb") as f: + return base64.b64encode(f.read()).decode("UTF-8") + + +class BaseSpeechRecognition: + def _load_sample(self, sample: Dict[str, Any]) -> Any: + path = sample[DefaultDataKeys.INPUT] + if ( + not os.path.isabs(path) + and DefaultDataKeys.METADATA in sample + and "root" in sample[DefaultDataKeys.METADATA] + ): + path = os.path.join(sample[DefaultDataKeys.METADATA]["root"], path) + speech_array, sampling_rate = sf.read(path) + sample[DefaultDataKeys.INPUT] = speech_array + sample[DefaultDataKeys.METADATA] = {"sampling_rate": sampling_rate} + return sample + + +class SpeechRecognitionFileDataSource(DataSource, BaseSpeechRecognition): + def __init__(self, filetype: Optional[str] = None): + super().__init__() + self.filetype = filetype + + def load_data( + self, + data: Tuple[str, Union[str, List[str]], Union[str, List[str]]], + dataset: Optional[Any] = None, + ) -> Union[Sequence[Mapping[str, Any]]]: + if self.filetype == "json": + file, input_key, target_key, field = data + else: + file, input_key, target_key = data + stage = self.running_stage.value + if self.filetype == "json" and field is not None: + dataset_dict = load_dataset(self.filetype, data_files={stage: str(file)}, field=field) + else: + dataset_dict = load_dataset(self.filetype, data_files={stage: str(file)}) + + dataset = dataset_dict[stage] + meta = {"root": os.path.dirname(file)} + return [ + { + DefaultDataKeys.INPUT: input_file, + DefaultDataKeys.TARGET: target, + DefaultDataKeys.METADATA: meta, + } + for input_file, target in zip(dataset[input_key], dataset[target_key]) + ] + + def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any: + return self._load_sample(sample) + + +class SpeechRecognitionCSVDataSource(SpeechRecognitionFileDataSource): + def __init__(self): + super().__init__(filetype="csv") + + +class SpeechRecognitionJSONDataSource(SpeechRecognitionFileDataSource): + def __init__(self): + super().__init__(filetype="json") + + +class SpeechRecognitionDatasetDataSource(DatasetDataSource, BaseSpeechRecognition): + def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Union[Sequence[Mapping[str, Any]]]: + if isinstance(data, HFDataset): + data = list(zip(data["file"], data["text"])) + return super().load_data(data, dataset) + + +class SpeechRecognitionPathsDataSource(PathsDataSource, BaseSpeechRecognition): + def __init__(self): + super().__init__(("wav", "ogg", "flac", "mat")) + + def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any: + return self._load_sample(sample) + + +class SpeechRecognitionPreprocess(Preprocess): + @requires_extras("audio") + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + ): + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.CSV: SpeechRecognitionCSVDataSource(), + DefaultDataSources.JSON: SpeechRecognitionJSONDataSource(), + DefaultDataSources.FILES: SpeechRecognitionPathsDataSource(), + DefaultDataSources.DATASETS: SpeechRecognitionDatasetDataSource(), + }, + default_data_source=DefaultDataSources.FILES, + deserializer=SpeechRecognitionDeserializer(), + ) + + def get_state_dict(self) -> Dict[str, Any]: + return self.transforms + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) + + +@dataclass(unsafe_hash=True, frozen=True) +class SpeechRecognitionBackboneState(ProcessState): + """The ``SpeechRecognitionBackboneState`` stores the backbone in use by the + :class:`~flash.audio.speech_recognition.data.SpeechRecognitionPostprocess` + """ + + backbone: str + + +class SpeechRecognitionPostprocess(Postprocess): + @requires_extras("audio") + def __init__(self): + super().__init__() + + self._backbone = None + self._tokenizer = None + + @property + def backbone(self): + backbone_state = self.get_state(SpeechRecognitionBackboneState) + if backbone_state is not None: + return backbone_state.backbone + + @property + def tokenizer(self): + if self.backbone is not None and self.backbone != self._backbone: + self._tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(self.backbone) + self._backbone = self.backbone + return self._tokenizer + + def per_batch_transform(self, batch: Any) -> Any: + # converts logits into greedy transcription + pred_ids = torch.argmax(batch.logits, dim=-1) + transcriptions = self.tokenizer.batch_decode(pred_ids) + return transcriptions + + def __getstate__(self): # TODO: Find out why this is being pickled + state = self.__dict__.copy() + state.pop("_tokenizer") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(self.backbone) + + +class SpeechRecognitionData(DataModule): + """Data Module for text classification tasks.""" + + preprocess_cls = SpeechRecognitionPreprocess + postprocess_cls = SpeechRecognitionPostprocess diff --git a/flash/audio/speech_recognition/model.py b/flash/audio/speech_recognition/model.py new file mode 100644 index 0000000000..15cdcef4f9 --- /dev/null +++ b/flash/audio/speech_recognition/model.py @@ -0,0 +1,79 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import warnings +from typing import Any, Callable, Dict, Mapping, Optional, Type, Union + +import torch +import torch.nn as nn + +from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES +from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding +from flash.audio.speech_recognition.data import SpeechRecognitionBackboneState +from flash.core.data.process import Serializer +from flash.core.data.states import CollateFn +from flash.core.model import Task +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _AUDIO_AVAILABLE + +if _AUDIO_AVAILABLE: + from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor + + +class SpeechRecognition(Task): + + backbones: FlashRegistry = SPEECH_RECOGNITION_BACKBONES + + required_extras = "audio" + + def __init__( + self, + backbone: str = "facebook/wav2vec2-base-960h", + loss_fn: Optional[Callable] = None, + optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, + learning_rate: float = 1e-5, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + ): + os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" + # disable HF thousand warnings + warnings.simplefilter("ignore") + # set os environ variable for multiprocesses + os.environ["PYTHONWARNINGS"] = "ignore" + + model = ( + self.backbones.get(backbone)() if backbone in self.backbones else Wav2Vec2ForCTC.from_pretrained(backbone) + ) + super().__init__( + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + learning_rate=learning_rate, + serializer=serializer, + ) + + self.save_hyperparameters() + + self.set_state(SpeechRecognitionBackboneState(backbone)) + self.set_state(CollateFn(DataCollatorCTCWithPadding(Wav2Vec2Processor.from_pretrained(backbone)))) + + def forward(self, batch: Dict[str, torch.Tensor]): + return self.model(batch["input_values"]) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return self(batch) + + def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: + out = self.model(batch["input_values"], labels=batch["labels"]) + out["logs"] = {"loss": out.loss} + return out diff --git a/flash/core/adapter.py b/flash/core/adapter.py new file mode 100644 index 0000000000..c7557b1977 --- /dev/null +++ b/flash/core/adapter.py @@ -0,0 +1,162 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import abstractmethod +from typing import Any, Callable, Optional + +from torch import nn +from torch.utils.data import DataLoader, Sampler + +import flash +from flash.core.data.auto_dataset import BaseAutoDataset +from flash.core.model import DatasetProcessor, ModuleWrapperBase, Task + + +class Adapter(DatasetProcessor, ModuleWrapperBase, nn.Module): + """The ``Adapter`` is a lightweight interface that can be used to encapsulate the logic from a particular + provider within a :class:`~flash.core.model.Task`.""" + + @classmethod + @abstractmethod + def from_task(cls, task: "flash.Task", **kwargs) -> "Adapter": + """Instantiate the adapter from the given :class:`~flash.core.model.Task`. + + This includes resolution / creation of backbones / heads and any other provider specific options. + """ + + def forward(self, x: Any) -> Any: + pass + + def training_step(self, batch: Any, batch_idx: int) -> Any: + pass + + def validation_step(self, batch: Any, batch_idx: int) -> None: + pass + + def test_step(self, batch: Any, batch_idx: int) -> None: + pass + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + pass + + def training_epoch_end(self, outputs) -> None: + pass + + def validation_epoch_end(self, outputs) -> None: + pass + + def test_epoch_end(self, outputs) -> None: + pass + + +class AdapterTask(Task): + """The ``AdapterTask`` is a :class:`~flash.core.model.Task` which wraps an :class:`~flash.core.adapter.Adapter` + and forwards all of the hooks. + + Args: + adapter: The :class:`~flash.core.adapter.Adapter` to wrap. + kwargs: Keyword arguments to be passed to the base :class:`~flash.core.model.Task`. + """ + + def __init__(self, adapter: Adapter, **kwargs): + super().__init__(**kwargs) + + self.adapter = adapter + + @property + def backbone(self) -> nn.Module: + return self.adapter.backbone + + def forward(self, x: Any) -> Any: + return self.adapter.forward(x) + + def training_step(self, batch: Any, batch_idx: int) -> Any: + return self.adapter.training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> None: + return self.adapter.validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> None: + return self.adapter.test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return self.adapter.predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + + def training_epoch_end(self, outputs) -> None: + return self.adapter.training_epoch_end(outputs) + + def validation_epoch_end(self, outputs) -> None: + return self.adapter.validation_epoch_end(outputs) + + def test_epoch_end(self, outputs) -> None: + return self.adapter.test_epoch_end(outputs) + + def process_train_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self.adapter.process_train_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler + ) + + def process_val_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self.adapter.process_val_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler + ) + + def process_test_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self.adapter.process_test_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler + ) + + def process_predict_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int = 1, + num_workers: int = 0, + pin_memory: bool = False, + collate_fn: Callable = lambda x: x, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self.adapter.process_predict_dataset( + dataset, batch_size, num_workers, pin_memory, collate_fn, shuffle, drop_last, sampler + ) diff --git a/flash/core/classification.py b/flash/core/classification.py index 61ee005ba9..b11e714528 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -21,7 +21,7 @@ from flash.core.data.data_source import DefaultDataKeys, LabelsState from flash.core.data.process import Serializer from flash.core.model import Task -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires Classification, Classifications = None, None if _FIFTYONE_AVAILABLE: @@ -38,10 +38,10 @@ def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch. class ClassificationTask(Task): - def __init__( self, *args, + num_classes: Optional[int] = None, loss_fn: Optional[Callable] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, multi_label: bool = False, @@ -49,7 +49,7 @@ def __init__( **kwargs, ) -> None: if metrics is None: - metrics = torchmetrics.Accuracy(subset_accuracy=multi_label) + metrics = torchmetrics.F1(num_classes) if (multi_label and num_classes) else torchmetrics.Accuracy() if loss_fn is None: loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy @@ -195,6 +195,7 @@ class FiftyOneLabels(ClassificationSerializer): list of FiftyOne labels (False) """ + @requires("fiftyone") def __init__( self, labels: Optional[List[str]] = None, @@ -203,9 +204,6 @@ def __init__( store_logits: bool = False, return_filepath: bool = False, ): - if not _FIFTYONE_AVAILABLE: - raise ModuleNotFoundError("Please, run `pip install fiftyone`.") - if multi_label and threshold is None: threshold = 0.5 diff --git a/flash/core/data/auto_dataset.py b/flash/core/data/auto_dataset.py index 6d81266348..fcd03fb18c 100644 --- a/flash/core/data/auto_dataset.py +++ b/flash/core/data/auto_dataset.py @@ -20,7 +20,7 @@ import flash from flash.core.data.utils import CurrentRunningStageFuncContext -DATA_TYPE = TypeVar('DATA_TYPE') +DATA_TYPE = TypeVar("DATA_TYPE") class BaseAutoDataset(Generic[DATA_TYPE]): @@ -41,7 +41,7 @@ class BaseAutoDataset(Generic[DATA_TYPE]): def __init__( self, data: DATA_TYPE, - data_source: 'flash.core.data.data_source.DataSource', + data_source: "flash.core.data.data_source.DataSource", running_stage: RunningStage, ) -> None: super().__init__() @@ -68,11 +68,11 @@ def running_stage(self, running_stage: RunningStage) -> None: self.load_sample: Callable[[DATA_TYPE, Optional[Any]], Any] = getattr( self.data_source, DataPipeline._resolve_function_hierarchy( - 'load_sample', + "load_sample", self.data_source, self.running_stage, DataSource, - ) + ), ) def _call_load_sample(self, sample: Any) -> Any: @@ -89,8 +89,10 @@ def _call_load_sample(self, sample: Any) -> Any: class AutoDataset(BaseAutoDataset[Sequence], Dataset): - """The ``AutoDataset`` is a ``BaseAutoDataset`` and a :class:`~torch.utils.data.Dataset`. The `data` argument - must be a ``Sequence`` (it must have a length).""" + """The ``AutoDataset`` is a ``BaseAutoDataset`` and a :class:`~torch.utils.data.Dataset`. + + The `data` argument must be a ``Sequence`` (it must have a length). + """ def __getitem__(self, index: int) -> Any: return self._call_load_sample(self.data[index]) @@ -100,8 +102,10 @@ def __len__(self) -> int: class IterableAutoDataset(BaseAutoDataset[Iterable], IterableDataset): - """The ``IterableAutoDataset`` is a ``BaseAutoDataset`` and a :class:`~torch.utils.data.IterableDataset`. The `data` - argument must be an ``Iterable``.""" + """The ``IterableAutoDataset`` is a ``BaseAutoDataset`` and a :class:`~torch.utils.data.IterableDataset`. + + The `data` argument must be an ``Iterable``. + """ def __iter__(self): self.data_iter = iter(self.data) diff --git a/flash/core/data/base_viz.py b/flash/core/data/base_viz.py index 7d1128cf93..4f426ff014 100644 --- a/flash/core/data/base_viz.py +++ b/flash/core/data/base_viz.py @@ -22,8 +22,8 @@ class BaseVisualization(BaseDataFetcher): - """ - This Base Class is used to create visualization tool on top of :class:`~flash.core.data.process.Preprocess` hooks. + """This Base Class is used to create visualization tool on top of :class:`~flash.core.data.process.Preprocess` + hooks. Override any of the ``show_{preprocess_hook_name}`` to receive the associated data and visualize them. @@ -105,16 +105,13 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage): As the :class:`~flash.core.data.process.Preprocess` hooks are injected within the threaded workers of the DataLoader, the data won't be accessible when using ``num_workers > 0``. - """ def _show(self, running_stage: RunningStage, func_names_list: List[str]) -> None: self.show(self.batches[running_stage], running_stage, func_names_list) def show(self, batch: Dict[str, Any], running_stage: RunningStage, func_names_list: List[str]) -> None: - """ - Override this function when you want to visualize a composition. - """ + """Override this function when you want to visualize a composition.""" # filter out the functions to visualise func_names_set: Set[str] = set(func_names_list) & set(_CALLBACK_FUNCS) if len(func_names_set) == 0: diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py index 12505bf181..dd0ed1e9dd 100644 --- a/flash/core/data/batch.py +++ b/flash/core/data/batch.py @@ -32,8 +32,8 @@ class _Sequential(torch.nn.Module): - """ - This class is used to chain 3 functions together for the _Preprocessor ``per_sample_transform`` function. + """This class is used to chain 3 functions together for the _Preprocessor ``per_sample_transform`` function. + 1. ``pre_tensor_transform`` 2. ``to_tensor_transform`` 3. ``post_tensor_transform`` @@ -41,7 +41,7 @@ class _Sequential(torch.nn.Module): def __init__( self, - preprocess: 'Preprocess', + preprocess: "Preprocess", pre_tensor_transform: Optional[Callable], to_tensor_transform: Optional[Callable], post_tensor_transform: Callable, @@ -101,11 +101,10 @@ def __str__(self) -> str: class _DeserializeProcessor(torch.nn.Module): - def __init__( self, - deserializer: 'Deserializer', - preprocess: 'Preprocess', + deserializer: "Deserializer", + preprocess: "Preprocess", pre_tensor_transform: Callable, to_tensor_transform: Callable, ): @@ -137,10 +136,9 @@ def forward(self, sample: str): class _SerializeProcessor(torch.nn.Module): - def __init__( self, - serializer: 'Serializer', + serializer: "Serializer", ): super().__init__() self.serializer = convert_to_modules(serializer) @@ -151,28 +149,28 @@ def forward(self, sample): class _Preprocessor(torch.nn.Module): """ - This class is used to encapsultate the following functions of a Preprocess Object: - Inside a worker: - per_sample_transform: Function to transform an individual sample - Inside a worker, it is actually make of 3 functions: - * pre_tensor_transform - * to_tensor_transform - * post_tensor_transform - collate: Function to merge sample into a batch - per_batch_transform: Function to transform an individual batch - * per_batch_transform - - Inside main process: - per_sample_transform: Function to transform an individual sample - * per_sample_transform_on_device - collate: Function to merge sample into a batch - per_batch_transform: Function to transform an individual batch - * per_batch_transform_on_device + This class is used to encapsultate the following functions of a Preprocess Object: + Inside a worker: + per_sample_transform: Function to transform an individual sample + Inside a worker, it is actually make of 3 functions: + * pre_tensor_transform + * to_tensor_transform + * post_tensor_transform + collate: Function to merge sample into a batch + per_batch_transform: Function to transform an individual batch + * per_batch_transform + + Inside main process: + per_sample_transform: Function to transform an individual sample + * per_sample_transform_on_device + collate: Function to merge sample into a batch + per_batch_transform: Function to transform an individual batch + * per_batch_transform_on_device """ def __init__( self, - preprocess: 'Preprocess', + preprocess: "Preprocess", collate_fn: Callable, per_sample_transform: Union[Callable, _Sequential], per_batch_transform: Callable, @@ -229,7 +227,10 @@ def forward(self, samples: Sequence[Any]) -> Any: with self._collate_context: samples, metadata = self._extract_metadata(samples) - samples = self.collate_fn(samples) + try: + samples = self.collate_fn(samples, metadata) + except TypeError: + samples = self.collate_fn(samples) if metadata and isinstance(samples, dict): samples[DefaultDataKeys.METADATA] = metadata self.callback.on_collate(samples, self.stage) @@ -256,16 +257,16 @@ def __str__(self) -> str: class _Postprocessor(torch.nn.Module): - """ - This class is used to encapsultate the following functions of a Postprocess Object: - Inside main process: - per_batch_transform: Function to transform a batch - per_sample_transform: Function to transform an individual sample - uncollate_fn: Function to split a batch into samples - per_sample_transform: Function to transform an individual sample - save_fn: Function to save all data - save_per_sample: Function to save an individual sample - is_serving: Whether the Postprocessor is used in serving mode. + """This class is used to encapsultate the following functions of a Postprocess Object: + + Inside main process: + per_batch_transform: Function to transform a batch + per_sample_transform: Function to transform an individual sample + uncollate_fn: Function to split a batch into samples + per_sample_transform: Function to transform an individual sample + save_fn: Function to save all data + save_per_sample: Function to save an individual sample + is_serving: Whether the Postprocessor is used in serving mode. """ def __init__( @@ -289,9 +290,10 @@ def __init__( @staticmethod def _extract_metadata(batch: Any) -> Tuple[Any, Optional[Any]]: - if isinstance(batch, Mapping): - return batch, batch.get(DefaultDataKeys.METADATA, None) - return batch, None + metadata = None + if isinstance(batch, Mapping) and DefaultDataKeys.METADATA in batch: + metadata = batch.pop(DefaultDataKeys.METADATA, None) + return batch, metadata def forward(self, batch: Sequence[Any]): batch, metadata = self._extract_metadata(batch) @@ -331,7 +333,6 @@ def __str__(self) -> str: def default_uncollate(batch: Any): """ This function is used to uncollate a batch into samples. - Examples: >>> a, b = default_uncollate(torch.rand((2,1))) """ @@ -346,7 +347,7 @@ def default_uncollate(batch: Any): if isinstance(batch, Mapping): return [batch_type(dict(zip(batch, default_uncollate(t)))) for t in zip(*batch.values())] - if isinstance(batch, tuple) and hasattr(batch, '_fields'): # namedtuple + if isinstance(batch, tuple) and hasattr(batch, "_fields"): # namedtuple return [batch_type(*default_uncollate(sample)) for sample in zip(*batch)] if isinstance(batch, Sequence) and not isinstance(batch, str): diff --git a/flash/core/data/callback.py b/flash/core/data/callback.py index add1e70c2c..b4c2aa93ee 100644 --- a/flash/core/data/callback.py +++ b/flash/core/data/callback.py @@ -10,6 +10,16 @@ class FlashCallback(Callback): + """``FlashCallback`` is an extension of :class:`pytorch_lightning.callbacks.Callback`. + + A callback is a self-contained program that can be reused across projects. Flash and Lightning have a callback + system to execute callbacks when needed. Callbacks should capture any NON-ESSENTIAL logic that is NOT required for + your lightning module to run. + + Same as PyTorch Lightning, Callbacks can be provided directly to the Trainer:: + + trainer = Trainer(callbacks=[MyCustomCallback()]) + """ def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None: """Called once a sample has been loaded using ``load_sample``.""" @@ -37,7 +47,6 @@ def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningSta class ControlFlow(FlashCallback): - def __init__(self, callbacks: List[FlashCallback]): self._callbacks = callbacks @@ -72,8 +81,7 @@ def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningSta class BaseDataFetcher(FlashCallback): - """ - This class is used to profile :class:`~flash.core.data.process.Preprocess` hook outputs. + """This class is used to profile :class:`~flash.core.data.process.Preprocess` hook outputs. By default, the callback won't profile the data being processed as it may lead to ``OOMError``. @@ -155,7 +163,6 @@ def from_inputs( 'val': {}, 'predict': {} } - """ def __init__(self, enabled: bool = False): @@ -195,12 +202,12 @@ def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningSta @contextmanager def enable(self): - """This function is used to enable to BaseDataFetcher""" + """This function is used to enable to BaseDataFetcher.""" self.enabled = True yield self.enabled = False - def attach_to_preprocess(self, preprocess: 'flash.core.data.process.Preprocess') -> None: + def attach_to_preprocess(self, preprocess: "flash.core.data.process.Preprocess") -> None: preprocess.add_callbacks([self]) self._preprocess = preprocess diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 5e6d1c8aab..c0ea53dd98 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -14,7 +14,20 @@ import json import os import platform -from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import ( + Any, + Callable, + Collection, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + TYPE_CHECKING, + Union, +) import numpy as np import pytorch_lightning as pl @@ -25,6 +38,7 @@ from torch.utils.data.dataset import IterableDataset, random_split, Subset from torch.utils.data.sampler import Sampler +import flash from flash.core.data.auto_dataset import BaseAutoDataset, IterableAutoDataset from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher @@ -32,7 +46,7 @@ from flash.core.data.data_source import DataSource, DefaultDataSources, LabelStudioDataSource from flash.core.data.splits import SplitDataset from flash.core.data.utils import _STAGES_PREFIX -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, requires if _FIFTYONE_AVAILABLE and TYPE_CHECKING: from fiftyone.core.collections import SampleCollection @@ -84,13 +98,16 @@ def __init__( postprocess: Optional[Postprocess] = None, data_fetcher: Optional[BaseDataFetcher] = None, val_split: Optional[float] = None, - batch_size: int = 1, + batch_size: int = 4, num_workers: Optional[int] = None, - sampler: Optional[Sampler] = None, + sampler: Optional[Type[Sampler]] = None, ) -> None: super().__init__() + if flash._IS_TESTING and torch.cuda.is_available(): + batch_size = 16 + self._data_source: DataSource = data_source self._preprocess: Optional[Preprocess] = preprocess self._postprocess: Optional[Postprocess] = postprocess @@ -124,7 +141,7 @@ def __init__( # TODO: figure out best solution for setting num_workers if num_workers is None: - if platform.system() == "Darwin" or platform.system() == "Windows": + if platform.system() in ("Darwin", "Windows"): num_workers = 0 else: num_workers = os.cpu_count() @@ -135,22 +152,22 @@ def __init__( @property def train_dataset(self) -> Optional[Dataset]: - """This property returns the train dataset""" + """This property returns the train dataset.""" return self._train_ds @property def val_dataset(self) -> Optional[Dataset]: - """This property returns the validation dataset""" + """This property returns the validation dataset.""" return self._val_ds @property def test_dataset(self) -> Optional[Dataset]: - """This property returns the test dataset""" + """This property returns the test dataset.""" return self._test_ds @property def predict_dataset(self) -> Optional[Dataset]: - """This property returns the predict dataset""" + """This property returns the predict dataset.""" return self._predict_ds @property @@ -163,8 +180,8 @@ def viz(self, viz: BaseVisualization) -> None: @staticmethod def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: - """ - This function is used to configure a :class:`~flash.core.data.callback.BaseDataFetcher`. + """This function is used to configure a :class:`~flash.core.data.callback.BaseDataFetcher`. + Override with your custom one. """ return BaseDataFetcher() @@ -189,9 +206,7 @@ def _reset_iterator(self, stage: str) -> Iterable[Any]: return iterator def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool = True) -> None: - """ - This function is used to handle transforms profiling for batch visualization. - """ + """This function is used to handle transforms profiling for batch visualization.""" # don't show in CI if os.getenv("FLASH_TESTING", "0") == "1": return None @@ -218,22 +233,22 @@ def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool if reset: self.data_fetcher.batches[stage] = {} - def show_train_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: + def show_train_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: """This function is used to visualize a batch from the train dataloader.""" stage_name: str = _STAGES_PREFIX[RunningStage.TRAINING] self._show_batch(stage_name, hooks_names, reset=reset) - def show_val_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: + def show_val_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: """This function is used to visualize a batch from the validation dataloader.""" stage_name: str = _STAGES_PREFIX[RunningStage.VALIDATING] self._show_batch(stage_name, hooks_names, reset=reset) - def show_test_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: + def show_test_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: """This function is used to visualize a batch from the test dataloader.""" stage_name: str = _STAGES_PREFIX[RunningStage.TESTING] self._show_batch(stage_name, hooks_names, reset=reset) - def show_predict_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: + def show_predict_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None: """This function is used to visualize a batch from the predict dataloader.""" stage_name: str = _STAGES_PREFIX[RunningStage.PREDICTING] self._show_batch(stage_name, hooks_names, reset=reset) @@ -254,16 +269,16 @@ def set_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, val def set_running_stages(self): if self._train_ds: - self.set_dataset_attribute(self._train_ds, 'running_stage', RunningStage.TRAINING) + self.set_dataset_attribute(self._train_ds, "running_stage", RunningStage.TRAINING) if self._val_ds: - self.set_dataset_attribute(self._val_ds, 'running_stage', RunningStage.VALIDATING) + self.set_dataset_attribute(self._val_ds, "running_stage", RunningStage.VALIDATING) if self._test_ds: - self.set_dataset_attribute(self._test_ds, 'running_stage', RunningStage.TESTING) + self.set_dataset_attribute(self._test_ds, "running_stage", RunningStage.TESTING) if self._predict_ds: - self.set_dataset_attribute(self._predict_ds, 'running_stage', RunningStage.PREDICTING) + self.set_dataset_attribute(self._predict_ds, "running_stage", RunningStage.PREDICTING) def _resolve_collate_fn(self, dataset: Dataset, running_stage: RunningStage) -> Optional[Callable]: if isinstance(dataset, (BaseAutoDataset, SplitDataset)): @@ -272,37 +287,84 @@ def _resolve_collate_fn(self, dataset: Dataset, running_stage: RunningStage) -> def _train_dataloader(self) -> DataLoader: train_ds: Dataset = self._train_ds() if isinstance(self._train_ds, Callable) else self._train_ds shuffle: bool = False + collate_fn = self._resolve_collate_fn(train_ds, RunningStage.TRAINING) + if isinstance(train_ds, IterableAutoDataset): + drop_last = False + else: + drop_last = len(train_ds) > self.batch_size + pin_memory = True + if self.sampler is None: + sampler = None shuffle = not isinstance(train_ds, (IterableDataset, IterableAutoDataset)) + else: + sampler = self.sampler(train_ds) + + if isinstance(getattr(self, "trainer", None), pl.Trainer): + return self.trainer.lightning_module.process_train_dataset( + train_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + collate_fn=collate_fn, + sampler=sampler, + ) + return DataLoader( train_ds, batch_size=self.batch_size, shuffle=shuffle, - sampler=self.sampler, + sampler=sampler, num_workers=self.num_workers, - pin_memory=True, - drop_last=True, - collate_fn=self._resolve_collate_fn(train_ds, RunningStage.TRAINING) + pin_memory=pin_memory, + drop_last=drop_last, + collate_fn=collate_fn, ) def _val_dataloader(self) -> DataLoader: val_ds: Dataset = self._val_ds() if isinstance(self._val_ds, Callable) else self._val_ds + collate_fn = self._resolve_collate_fn(val_ds, RunningStage.VALIDATING) + pin_memory = True + + if isinstance(getattr(self, "trainer", None), pl.Trainer): + return self.trainer.lightning_module.process_val_dataset( + val_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + ) + return DataLoader( val_ds, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=True, - collate_fn=self._resolve_collate_fn(val_ds, RunningStage.VALIDATING) + pin_memory=pin_memory, + collate_fn=collate_fn, ) def _test_dataloader(self) -> DataLoader: test_ds: Dataset = self._test_ds() if isinstance(self._test_ds, Callable) else self._test_ds + collate_fn = self._resolve_collate_fn(test_ds, RunningStage.TESTING) + pin_memory = True + + if isinstance(getattr(self, "trainer", None), pl.Trainer): + return self.trainer.lightning_module.process_test_dataset( + test_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + ) + return DataLoader( test_ds, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=True, - collate_fn=self._resolve_collate_fn(test_ds, RunningStage.TESTING) + pin_memory=pin_memory, + collate_fn=collate_fn, ) def _predict_dataloader(self) -> DataLoader: @@ -311,12 +373,21 @@ def _predict_dataloader(self) -> DataLoader: batch_size = self.batch_size else: batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1) + + collate_fn = self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING) + pin_memory = True + + if isinstance(getattr(self, "trainer", None), pl.Trainer): + return self.trainer.lightning_module.process_predict_dataset( + predict_ds, + batch_size=batch_size, + num_workers=self.num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + ) + return DataLoader( - predict_ds, - batch_size=batch_size, - num_workers=self.num_workers, - pin_memory=True, - collate_fn=self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING) + predict_ds, batch_size=batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=collate_fn ) @property @@ -326,6 +397,13 @@ def num_classes(self) -> Optional[int]: n_cls_test = getattr(self.test_dataset, "num_classes", None) return n_cls_train or n_cls_val or n_cls_test + @property + def multi_label(self) -> Optional[bool]: + multi_label_train = getattr(self.train_dataset, "multi_label", None) + multi_label_val = getattr(self.val_dataset, "multi_label", None) + multi_label_test = getattr(self.test_dataset, "multi_label", None) + return multi_label_train or multi_label_val or multi_label_test + @property def data_source(self) -> Optional[DataSource]: return self._data_source @@ -392,9 +470,9 @@ def from_data_source( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - sampler: Optional[Sampler] = None, + sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given inputs to :meth:`~flash.core.data.data_source.DataSource.load_data` (``train_data``, ``val_data``, ``test_data``, ``predict_data``). The data source will be resolved from the instantiated @@ -428,7 +506,7 @@ def from_data_source( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -492,9 +570,9 @@ def from_folders( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - sampler: Optional[Sampler] = None, + sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS` @@ -521,7 +599,7 @@ def from_folders( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -575,13 +653,13 @@ def from_files( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - sampler: Optional[Sampler] = None, + sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': - """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given sequences of files using - the :class:`~flash.core.data.data_source.DataSource` - of name :attr:`~flash.core.data.data_source.DefaultDataSources.FILES` - from the passed or constructed :class:`~flash.core.data.process.Preprocess`. + ) -> "DataModule": + """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given sequences of files + using the :class:`~flash.core.data.data_source.DataSource` of name + :attr:`~flash.core.data.data_source.DefaultDataSources.FILES` from the passed or constructed + :class:`~flash.core.data.process.Preprocess`. Args: train_files: A sequence of files to use as the train inputs. @@ -607,7 +685,7 @@ def from_files( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -662,9 +740,9 @@ def from_tensors( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - sampler: Optional[Sampler] = None, + sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given tensors using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.TENSOR` @@ -694,7 +772,7 @@ def from_tensors( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -749,9 +827,9 @@ def from_numpy( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - sampler: Optional[Sampler] = None, + sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given numpy array using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` @@ -781,7 +859,7 @@ def from_numpy( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -835,9 +913,10 @@ def from_json( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - sampler: Optional[Sampler] = None, + sampler: Optional[Type[Sampler]] = None, + field: Optional[str] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given JSON files using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.JSON` @@ -866,7 +945,8 @@ def from_json( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. + field: To specify the field that holds the data in the JSON file. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -883,13 +963,35 @@ def from_json( "to_tensor_transform": torch.as_tensor, }, ) + + # In the case where the data is of the form: + # { + # "version": 0.0.x, + # "data": [ + # { + # "input_field" : "input_data", + # "target_field" : "target_output" + # }, + # ... + # ] + # } + + data_module = DataModule.from_json( + "input", + "target", + train_file="train_data.json", + train_transform={ + "to_tensor_transform": torch.as_tensor, + }, + feild="data" + ) """ return cls.from_data_source( DefaultDataSources.JSON, - (train_file, input_fields, target_fields), - (val_file, input_fields, target_fields), - (test_file, input_fields, target_fields), - (predict_file, input_fields, target_fields), + (train_file, input_fields, target_fields, field), + (val_file, input_fields, target_fields, field), + (test_file, input_fields, target_fields, field), + (predict_file, input_fields, target_fields, field), train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, @@ -921,9 +1023,9 @@ def from_csv( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - sampler: Optional[Sampler] = None, + sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given CSV files using the :class:`~flash.core.data.data_source.DataSource` of name :attr:`~flash.core.data.data_source.DefaultDataSources.CSV` @@ -952,7 +1054,7 @@ def from_csv( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -1005,12 +1107,12 @@ def from_datasets( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - sampler: Optional[Sampler] = None, + sampler: Optional[Type[Sampler]] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given datasets using the :class:`~flash.core.data.data_source.DataSource` - of name :attr:`~flash.core.data.data_source.DefaultDataSources.DATASET` + of name :attr:`~flash.core.data.data_source.DefaultDataSources.DATASETS` from the passed or constructed :class:`~flash.core.data.process.Preprocess`. Args: @@ -1034,7 +1136,7 @@ def from_datasets( val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -1051,7 +1153,7 @@ def from_datasets( ) """ return cls.from_data_source( - DefaultDataSources.DATASET, + DefaultDataSources.DATASETS, train_dataset, val_dataset, test_dataset, @@ -1070,6 +1172,7 @@ def from_datasets( ) @classmethod + @requires("fiftyone") def from_fiftyone( cls, train_dataset: Optional[SampleCollection] = None, @@ -1086,7 +1189,7 @@ def from_fiftyone( batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given FiftyOne Datasets using the :class:`~flash.core.data.data_source.DataSource` of name @@ -1133,9 +1236,6 @@ def from_fiftyone( }, ) """ - if not _FIFTYONE_AVAILABLE: - raise ModuleNotFoundError("Please, `pip install fiftyone`.") - return cls.from_data_source( DefaultDataSources.FIFTYONE, train_dataset, @@ -1170,7 +1270,7 @@ def from_labelstudio( num_workers: Optional[int] = None, sampler: Optional[Sampler] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given export file and data directory using the :class:`~flash.core.data.data_source.DataSource` of name @@ -1212,10 +1312,10 @@ def from_labelstudio( ) """ data = { - 'data_folder': data_folder, - 'export_json': export_json, - 'split': val_split, - 'multi_label': preprocess_kwargs.get('multi_label', False) + "data_folder": data_folder, + "export_json": export_json, + "split": val_split, + "multi_label": preprocess_kwargs.get("multi_label", False), } return cls.from_data_source( DefaultDataSources.LABELSTUDIO, diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index 2d4a2bf1d7..d00618ff05 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -49,7 +49,8 @@ def set_state(self, state: ProcessState): else: rank_zero_warn( f"Attempted to add a state ({state}) after the data pipeline has already been initialized. This will" - " only have an effect when a new data pipeline is created.", UserWarning + " only have an effect when a new data pipeline is created.", + UserWarning, ) def get_state(self, state_type: Type[ProcessState]) -> Optional[ProcessState]: @@ -124,12 +125,10 @@ def example_input(self) -> str: @staticmethod def _is_overriden(method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None) -> bool: - """ - Cropped Version of - https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/model_helpers.py - """ + """Cropped Version of https://github.com/PyTorchLightning/pytorch- + lightning/blob/master/pytorch_lightning/utilities/model_helpers.py.""" - current_method_name = method_name if prefix is None else f'{prefix}_{method_name}' + current_method_name = method_name if prefix is None else f"{prefix}_{method_name}" if not hasattr(process_obj, current_method_name): return False @@ -140,15 +139,13 @@ def _is_overriden(method_name: str, process_obj, super_obj: Any, prefix: Optiona def _is_overriden_recursive( cls, method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None ) -> bool: - """ - Cropped Version of - https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/model_helpers.py - """ + """Cropped Version of https://github.com/PyTorchLightning/pytorch- + lightning/blob/master/pytorch_lightning/utilities/model_helpers.py.""" assert isinstance(process_obj, super_obj) if prefix is None and not hasattr(super_obj, method_name): raise MisconfigurationException(f"This function doesn't belong to the parent class {super_obj}") - current_method_name = method_name if prefix is None else f'{prefix}_{method_name}' + current_method_name = method_name if prefix is None else f"{prefix}_{method_name}" if not hasattr(process_obj, current_method_name): return DataPipeline._is_overriden_recursive(method_name, process_obj, super_obj) @@ -167,8 +164,10 @@ def _identity(samples: Sequence[Any]) -> Sequence[Any]: def deserialize_processor(self) -> _DeserializeProcessor: return self._create_collate_preprocessors(RunningStage.PREDICTING)[0] - def worker_preprocessor(self, running_stage: RunningStage, is_serving: bool = False) -> _Preprocessor: - return self._create_collate_preprocessors(running_stage, is_serving=is_serving)[1] + def worker_preprocessor( + self, running_stage: RunningStage, collate_fn: Optional[Callable] = None, is_serving: bool = False + ) -> _Preprocessor: + return self._create_collate_preprocessors(running_stage, collate_fn=collate_fn, is_serving=is_serving)[1] def device_preprocessor(self, running_stage: RunningStage) -> _Preprocessor: return self._create_collate_preprocessors(running_stage)[2] @@ -189,19 +188,19 @@ def _resolve_function_hierarchy( prefixes = [] if stage in (RunningStage.TRAINING, RunningStage.TUNING): - prefixes += ['train', 'fit'] + prefixes += ["train", "fit"] elif stage == RunningStage.VALIDATING: - prefixes += ['val', 'fit'] + prefixes += ["val", "fit"] elif stage == RunningStage.TESTING: - prefixes += ['test'] + prefixes += ["test"] elif stage == RunningStage.PREDICTING: - prefixes += ['predict'] + prefixes += ["predict"] prefixes += [None] for prefix in prefixes: if cls._is_overriden(function_name, process_obj, object_type, prefix=prefix): - return function_name if prefix is None else f'{prefix}_{function_name}' + return function_name if prefix is None else f"{prefix}_{function_name}" return function_name @@ -226,8 +225,7 @@ def _create_collate_preprocessors( preprocess._default_collate = collate_fn func_names: Dict[str, str] = { - k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess) - for k in self.PREPROCESS_FUNCS + k: self._resolve_function_hierarchy(k, preprocess, stage, Preprocess) for k in self.PREPROCESS_FUNCS } collate_fn: Callable = getattr(preprocess, func_names["collate"]) @@ -247,8 +245,8 @@ def _create_collate_preprocessors( is_per_overriden = per_batch_transform_overriden and per_sample_transform_on_device_overriden if collate_in_worker_from_transform is None and is_per_overriden: raise MisconfigurationException( - f'{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` ' - f'are mutually exclusive for stage {stage}' + f"{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` " + f"are mutually exclusive for stage {stage}" ) if isinstance(collate_in_worker_from_transform, bool): @@ -258,9 +256,9 @@ def _create_collate_preprocessors( per_sample_transform_on_device_overriden, collate_fn ) - worker_collate_fn = worker_collate_fn.collate_fn if isinstance( - worker_collate_fn, _Preprocessor - ) else worker_collate_fn + worker_collate_fn = ( + worker_collate_fn.collate_fn if isinstance(worker_collate_fn, _Preprocessor) else worker_collate_fn + ) assert_contains_tensor = self._is_overriden_recursive( "to_tensor_transform", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage] @@ -269,26 +267,29 @@ def _create_collate_preprocessors( deserialize_processor = _DeserializeProcessor( self._deserializer, preprocess, - getattr(preprocess, func_names['pre_tensor_transform']), - getattr(preprocess, func_names['to_tensor_transform']), + getattr(preprocess, func_names["pre_tensor_transform"]), + getattr(preprocess, func_names["to_tensor_transform"]), ) worker_preprocessor = _Preprocessor( - preprocess, worker_collate_fn, + preprocess, + worker_collate_fn, _Sequential( preprocess, - None if is_serving else getattr(preprocess, func_names['pre_tensor_transform']), - None if is_serving else getattr(preprocess, func_names['to_tensor_transform']), - getattr(preprocess, func_names['post_tensor_transform']), + None if is_serving else getattr(preprocess, func_names["pre_tensor_transform"]), + None if is_serving else getattr(preprocess, func_names["to_tensor_transform"]), + getattr(preprocess, func_names["post_tensor_transform"]), stage, assert_contains_tensor=assert_contains_tensor, - ), getattr(preprocess, func_names['per_batch_transform']), stage + ), + getattr(preprocess, func_names["per_batch_transform"]), + stage, ) worker_preprocessor._original_collate_fn = original_collate_fn device_preprocessor = _Preprocessor( preprocess, device_collate_fn, - getattr(preprocess, func_names['per_sample_transform_on_device']), - getattr(preprocess, func_names['per_batch_transform_on_device']), + getattr(preprocess, func_names["per_sample_transform_on_device"]), + getattr(preprocess, func_names["per_batch_transform_on_device"]), stage, apply_per_sample_transform=device_collate_fn != self._identity, on_device=True, @@ -297,7 +298,7 @@ def _create_collate_preprocessors( @staticmethod def _model_transfer_to_device_wrapper( - func: Callable, preprocessor: _Preprocessor, model: 'Task', stage: RunningStage + func: Callable, preprocessor: _Preprocessor, model: "Task", stage: RunningStage ) -> Callable: if not isinstance(func, _StageOrchestrator): @@ -307,7 +308,7 @@ def _model_transfer_to_device_wrapper( return func @staticmethod - def _model_predict_step_wrapper(func: Callable, postprocessor: _Postprocessor, model: 'Task') -> Callable: + def _model_predict_step_wrapper(func: Callable, postprocessor: _Postprocessor, model: "Task") -> Callable: if not isinstance(func, _StageOrchestrator): _original = func @@ -318,24 +319,22 @@ def _model_predict_step_wrapper(func: Callable, postprocessor: _Postprocessor, m return func @staticmethod - def _get_dataloader(model: 'Task', loader_name: str) -> Tuple[DataLoader, str]: + def _get_dataloader(model: "Task", loader_name: str) -> Tuple[DataLoader, str]: dataloader, attr_name = None, None if hasattr(model, loader_name): dataloader = getattr(model, loader_name) attr_name = loader_name - elif model.trainer and hasattr(model.trainer, 'datamodule') and model.trainer.datamodule: - dataloader = getattr(model, f'trainer.datamodule.{loader_name}', None) - attr_name = f'trainer.datamodule.{loader_name}' + elif model.trainer and hasattr(model.trainer, "datamodule") and model.trainer.datamodule: + dataloader = getattr(model, f"trainer.datamodule.{loader_name}", None) + attr_name = f"trainer.datamodule.{loader_name}" return dataloader, attr_name @staticmethod - def _set_loader(model: 'Task', loader_name: str, new_loader: DataLoader) -> None: - """ - This function is used to set the loader to model and/or datamodule - """ - *intermediates, final_name = loader_name.split('.') + def _set_loader(model: "Task", loader_name: str, new_loader: DataLoader) -> None: + """This function is used to set the loader to model and/or datamodule.""" + *intermediates, final_name = loader_name.split(".") curr_attr = model # This relies on python calling all non-integral types by reference. @@ -348,7 +347,7 @@ def _set_loader(model: 'Task', loader_name: str, new_loader: DataLoader) -> None def _attach_preprocess_to_model( self, - model: 'Task', + model: "Task", stage: Optional[RunningStage] = None, device_transform_only: bool = False, is_serving: bool = False, @@ -363,7 +362,7 @@ def _attach_preprocess_to_model( for stage in stages: - loader_name = f'{_STAGES_PREFIX[stage]}_dataloader' + loader_name = f"{_STAGES_PREFIX[stage]}_dataloader" dataloader, whole_attr_name = self._get_dataloader(model, loader_name) @@ -387,8 +386,8 @@ def _attach_preprocess_to_model( if isinstance(loader, DataLoader): dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")} - _, dl_args['collate_fn'], device_collate_fn = self._create_collate_preprocessors( - stage=stage, collate_fn=dl_args['collate_fn'], is_serving=is_serving + _, dl_args["collate_fn"], device_collate_fn = self._create_collate_preprocessors( + stage=stage, collate_fn=dl_args["collate_fn"], is_serving=is_serving ) if isinstance(dl_args["dataset"], IterableDataset): @@ -411,8 +410,8 @@ def _attach_preprocess_to_model( self._set_loader(model, whole_attr_name, dataloader) - model.transfer_batch_to_device = ( - self._model_transfer_to_device_wrapper(model.transfer_batch_to_device, device_collate_fn, model, stage) + model.transfer_batch_to_device = self._model_transfer_to_device_wrapper( + model.transfer_batch_to_device, device_collate_fn, model, stage ) def _create_uncollate_postprocessors( @@ -453,10 +452,10 @@ def _create_uncollate_postprocessors( def _attach_postprocess_to_model( self, - model: 'Task', + model: "Task", stage: RunningStage, is_serving: bool = False, - ) -> 'Task': + ) -> "Task": model.predict_step = self._model_predict_step_wrapper( model.predict_step, self._create_uncollate_postprocessors(stage, is_serving=is_serving), model ) @@ -464,7 +463,7 @@ def _attach_postprocess_to_model( def _attach_to_model( self, - model: 'Task', + model: "Task", stage: RunningStage = None, is_serving: bool = False, ): @@ -474,13 +473,13 @@ def _attach_to_model( if not stage or stage == RunningStage.PREDICTING: self._attach_postprocess_to_model(model, RunningStage.PREDICTING, is_serving=is_serving) - def _detach_from_model(self, model: 'Task', stage: Optional[RunningStage] = None): + def _detach_from_model(self, model: "Task", stage: Optional[RunningStage] = None): self._detach_preprocessing_from_model(model, stage) if not stage or stage == RunningStage.PREDICTING: self._detach_postprocess_from_model(model) - def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[RunningStage] = None): + def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[RunningStage] = None): if not stage: stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] elif isinstance(stage, RunningStage): @@ -499,7 +498,7 @@ def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[Runnin if not device_collate: device_collate = self._identity - loader_name = f'{_STAGES_PREFIX[stage]}_dataloader' + loader_name = f"{_STAGES_PREFIX[stage]}_dataloader" dataloader, whole_attr_name = self._get_dataloader(model, loader_name) @@ -521,11 +520,11 @@ def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[Runnin if isinstance(loader, DataLoader): dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")} - if isinstance(dl_args['collate_fn'], _Preprocessor): + if isinstance(dl_args["collate_fn"], _Preprocessor): dl_args["collate_fn"] = dl_args["collate_fn"]._original_collate_fn if isinstance(dl_args["dataset"], IterableAutoDataset): - del dl_args['sampler'] + del dl_args["sampler"] del dl_args["batch_sampler"] @@ -542,9 +541,9 @@ def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[Runnin self._set_loader(model, whole_attr_name, dataloader) @staticmethod - def _detach_postprocess_from_model(model: 'Task'): + def _detach_postprocess_from_model(model: "Task"): - if hasattr(model.predict_step, '_original'): + if hasattr(model.predict_step, "_original"): # don't delete the predict_step here since we don't know # if any other pipeline is attached which may rely on this! model.predict_step = model.predict_step._original @@ -574,10 +573,10 @@ class _StageOrchestrator: RunningStage.VALIDATING: RunningStage.VALIDATING, RunningStage.TESTING: RunningStage.TESTING, RunningStage.PREDICTING: RunningStage.PREDICTING, - RunningStage.TUNING: RunningStage.TUNING + RunningStage.TUNING: RunningStage.TUNING, } - def __init__(self, func_to_wrap: Callable, model: 'Task') -> None: + def __init__(self, func_to_wrap: Callable, model: "Task") -> None: self.func = func_to_wrap self._stage_mapping = {k: None for k in RunningStage} diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py index b43c42766c..085e9510cb 100644 --- a/flash/core/data/data_source.py +++ b/flash/core/data/data_source.py @@ -14,7 +14,9 @@ import json import os import typing +import warnings from dataclasses import dataclass +from functools import partial from inspect import signature from pathlib import Path from typing import ( @@ -57,7 +59,9 @@ else: fol = None from copy import deepcopy -from flash.core.utilities.imports import _TEXT_AVAILABLE, _PYTORCHVIDEO_AVAILABLE + +from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE, _TEXT_AVAILABLE + if _PYTORCHVIDEO_AVAILABLE: from torchvision.datasets.folder import default_loader @@ -138,10 +142,8 @@ def has_len(data: Union[Sequence[Any], Iterable[Any]]) -> bool: @dataclass(unsafe_hash=True, frozen=True) class LabelsState(ProcessState): - """ - A :class:`~flash.core.data.properties.ProcessState` containing ``labels``, - a mapping from class index to label. - """ + """A :class:`~flash.core.data.properties.ProcessState` containing ``labels``, a mapping from class index to + label.""" labels: Optional[Sequence[str]] @@ -162,7 +164,7 @@ class DefaultDataSources(LightningEnum): TENSORS = "tensors" CSV = "csv" JSON = "json" - DATASET = "dataset" + DATASETS = "datasets" FIFTYONE = "fiftyone" LABELSTUDIO = "labelstudio" @@ -185,16 +187,26 @@ def __hash__(self) -> int: return hash(self.value) +class BaseDataFormat(LightningEnum): + """The base class for creating ``data_format`` for :class:`~flash.core.data.data_source.DataSource`.""" + + def __hash__(self) -> int: + return hash(self.value) + + class MockDataset: - """The ``MockDataset`` catches any metadata that is attached through ``__setattr__``. This is passed to + """The ``MockDataset`` catches any metadata that is attached through ``__setattr__``. + + This is passed to :meth:`~flash.core.data.data_source.DataSource.load_data` so that attributes can be set on the generated - data set.""" + data set. + """ def __init__(self): self.metadata = {} def __setattr__(self, key, value): - if key != 'metadata': + if key != "metadata": self.metadata[key] = value object.__setattr__(self, key, value) @@ -203,9 +215,12 @@ def __setattr__(self, key, value): class DataSource(Generic[DATA_TYPE], Properties, Module): - """The ``DataSource`` class encapsulates two hooks: ``load_data`` and ``load_sample``. The + """The ``DataSource`` class encapsulates two hooks: ``load_data`` and ``load_sample``. + + The :meth:`~flash.core.data.data_source.DataSource.to_datasets` method can then be used to automatically construct data - sets from the hooks.""" + sets from the hooks. + """ @staticmethod def load_data( @@ -272,10 +287,10 @@ def to_datasets( test_data: Optional[DATA_TYPE] = None, predict_data: Optional[DATA_TYPE] = None, ) -> Tuple[Optional[BaseAutoDataset], ...]: - """Construct data sets (of type :class:`~flash.core.data.auto_dataset.BaseAutoDataset`) from this data source by - calling :meth:`~flash.core.data.data_source.DataSource.load_data` with each of the ``*_data`` arguments. If an - argument is given as ``None`` then no dataset will be created for that stage (``train``, ``val``, ``test``, - ``predict``). + """Construct data sets (of type :class:`~flash.core.data.auto_dataset.BaseAutoDataset`) from this data + source by calling :meth:`~flash.core.data.data_source.DataSource.load_data` with each of the ``*_data`` + arguments. If an argument is given as ``None`` then no dataset will be created for that stage (``train``, + ``val``, ``test``, ``predict``). Args: train_data: The input to :meth:`~flash.core.data.data_source.DataSource.load_data` to use to create the @@ -388,10 +403,9 @@ def load_data( inputs, targets = data if targets is None: return self.predict_load_data(data) - return [{ - DefaultDataKeys.INPUT: input, - DefaultDataKeys.TARGET: target - } for input, target in zip(inputs, targets)] + return [ + {DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in zip(inputs, targets) + ] @staticmethod def predict_load_data(data: Sequence[SEQUENCE_DATA_TYPE]) -> Sequence[Mapping[str, Any]]: @@ -409,10 +423,16 @@ class PathsDataSource(SequenceDataSource): :class:`~flash.core.data.data_source.LabelsState`. """ - def __init__(self, extensions: Optional[Tuple[str, ...]] = None, labels: Optional[Sequence[str]] = None): + def __init__( + self, + extensions: Optional[Tuple[str, ...]] = None, + loader: Optional[Callable[[str], Any]] = None, + labels: Optional[Sequence[str]] = None, + ): super().__init__(labels=labels) self.extensions = extensions + self.loader = loader @staticmethod def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: @@ -437,9 +457,9 @@ def isdir(data: Union[str, Tuple[List[str], List[Any]]]) -> bool: # data is not path-like (e.g. it may be a list of paths) return False - def load_data(self, - data: Union[str, Tuple[List[str], List[Any]]], - dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: + def load_data( + self, data: Union[str, Tuple[List[str], List[Any]]], dataset: Optional[Any] = None + ) -> Sequence[Mapping[str, Any]]: if self.isdir(data): classes, class_to_idx = self.find_classes(data) if not classes: @@ -458,9 +478,9 @@ def load_data(self, ) ) - def predict_load_data(self, - data: Union[str, List[str]], - dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: + def predict_load_data( + self, data: Union[str, List[str]], dataset: Optional[Any] = None + ) -> Sequence[Mapping[str, Any]]: if self.isdir(data): data = [os.path.join(data, file) for file in os.listdir(data)] @@ -476,6 +496,135 @@ def predict_load_data(self, ) ) + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: + path = sample[DefaultDataKeys.INPUT] + + if self.loader is not None: + sample[DefaultDataKeys.INPUT] = self.loader(path) + + sample[DefaultDataKeys.METADATA] = { + "filepath": path, + } + return sample + + +class LoaderDataFrameDataSource( + DataSource[Tuple[pd.DataFrame, str, Union[str, List[str]], Optional[str], Optional[str]]] +): + def __init__(self, loader: Callable[[str], Any]): + super().__init__() + + self.loader = loader + + @staticmethod + def _walk_files(root: str) -> Iterator[str]: + for root, _, files in os.walk(root): + for file in files: + yield os.path.join(root, file) + + @staticmethod + def _default_resolver(root: str, id: str): + if os.path.isabs(id): + return id + + pattern = f"*{id}*" + + try: + return str(next(Path(root).rglob(pattern))) + except StopIteration: + raise ValueError( + f"Found no matches for pattern: {pattern} in directory: {root}. File IDs should uniquely identify the " + "file to load." + ) + + @staticmethod + def _resolve_file(resolver: Callable[[str, str], str], root: str, input_key: str, row: pd.Series) -> pd.Series: + row[input_key] = resolver(root, row[input_key]) + return row + + @staticmethod + def _resolve_target(label_to_class: Dict[str, int], target_key: str, row: pd.Series) -> pd.Series: + row[target_key] = label_to_class[row[target_key]] + return row + + @staticmethod + def _resolve_multi_target(target_keys: List[str], row: pd.Series) -> pd.Series: + row[target_keys[0]] = [row[target_key] for target_key in target_keys] + return row + + def load_data( + self, + data: Tuple[pd.DataFrame, str, Union[str, List[str]], Optional[str], Optional[str]], + dataset: Optional[Any] = None, + ) -> Sequence[Mapping[str, Any]]: + data, input_key, target_keys, root, resolver = data + + if isinstance(data, (str, Path)): + data = str(data) + data_frame = pd.read_csv(data) + if root is None: + root = os.path.dirname(data) + else: + data_frame = data + + if root is None: + root = "" + + if resolver is None: + warnings.warn("Using default resolver, this may take a while.", UserWarning) + resolver = self._default_resolver + + tqdm.pandas(desc="Resolving files") + data_frame = data_frame.progress_apply(partial(self._resolve_file, resolver, root, input_key), axis=1) + + if not self.predicting: + if isinstance(target_keys, List): + dataset.multi_label = True + dataset.num_classes = len(target_keys) + self.set_state(LabelsState(target_keys)) + data_frame = data_frame.apply(partial(self._resolve_multi_target, target_keys), axis=1) + target_keys = target_keys[0] + else: + dataset.multi_label = False + if self.training: + labels = list(sorted(data_frame[target_keys].unique())) + dataset.num_classes = len(labels) + self.set_state(LabelsState(labels)) + + labels = self.get_state(LabelsState) + + if labels is not None: + labels = labels.labels + label_to_class = {v: k for k, v in enumerate(labels)} + data_frame = data_frame.apply(partial(self._resolve_target, label_to_class, target_keys), axis=1) + + return [ + { + DefaultDataKeys.INPUT: row[input_key], + DefaultDataKeys.TARGET: row[target_keys], + } + for _, row in data_frame.iterrows() + ] + else: + return [ + { + DefaultDataKeys.INPUT: row[input_key], + } + for _, row in data_frame.iterrows() + ] + + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: + # TODO: simplify this duplicated code from PathsDataSource + path = sample[DefaultDataKeys.INPUT] + + if self.loader is not None: + sample[DefaultDataKeys.INPUT] = self.loader(path) + + sample[DefaultDataKeys.METADATA] = { + "filepath": path, + } + return sample + class TensorDataSource(SequenceDataSource[torch.Tensor]): """The ``TensorDataSource`` is a ``SequenceDataSource`` which expects the input to @@ -492,15 +641,15 @@ class FiftyOneDataSource(DataSource[SampleCollection]): :meth:`~flash.core.data.data_source.DataSource.load_data` to be a ``fiftyone.core.collections.SampleCollection``.""" def __init__(self, label_field: str = "ground_truth"): - if not _FIFTYONE_AVAILABLE: - raise ModuleNotFoundError("Please, run `pip install fiftyone`.") super().__init__() self.label_field = label_field @property + @requires("fiftyone") def label_cls(self): return fol.Label + @requires("fiftyone") def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: self._validate(data) @@ -520,26 +669,29 @@ def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Se def to_idx(t): return [class_to_idx[x] for x in t] + else: def to_idx(t): return class_to_idx[t] - return [{ - DefaultDataKeys.INPUT: f, - DefaultDataKeys.TARGET: to_idx(t), - } for f, t in zip(filepaths, targets)] + return [ + { + DefaultDataKeys.INPUT: f, + DefaultDataKeys.TARGET: to_idx(t), + } + for f, t in zip(filepaths, targets) + ] @staticmethod + @requires("fiftyone") def predict_load_data(data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: return [{DefaultDataKeys.INPUT: f} for f in data.values("filepath")] def _validate(self, data): label_type = data._get_label_field_type(self.label_field) if not issubclass(label_type, self.label_cls): - raise ValueError( - "Expected field '%s' to have type %s; found %s" % (self.label_field, self.label_cls, label_type) - ) + raise ValueError(f"Expected field '{self.label_field}' to have type {self.label_cls}; found {label_type}") def _get_classes(self, data): classes = data.classes.get(self.label_field, None) @@ -557,6 +709,7 @@ def _get_classes(self, data): class LabelStudioDataSource(DataSource): """The ``LabelStudioDatasource`` expects the input to :meth:`~flash.core.data.data_source.DataSource.load_data` to be a json export from label studio.""" + def __init__(self): super().__init__() self.results = [] @@ -567,36 +720,34 @@ def __init__(self): self.num_classes = 0 def load_data(self, data: Optional[Any] = None, dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: - """ - Iterate through all tasks in exported data and construct train\test\val results - """ + """Iterate through all tasks in exported data and construct train\test\val results.""" if data and isinstance(data, dict): - self._data_folder = data.get('data_folder') - with open(data.get('export_json')) as f: + self._data_folder = data.get("data_folder") + with open(data.get("export_json")) as f: self._raw_data = json.load(f) - self.multi_label = data.get('multi_label') - self.split = data.get('split') + self.multi_label = data.get("multi_label") + self.split = data.get("split") for task in self._raw_data: - for annotation in task['annotations']: + for annotation in task["annotations"]: # extracting data types from tasks - [self.data_types.add(key) for key in task.get('data')] + [self.data_types.add(key) for key in task.get("data")] # Adding ground_truth annotation to separate dataset - result = annotation['result'] + result = annotation["result"] for res in result: - t = res['type'] - for label in res['value'][t]: + t = res["type"] + for label in res["value"][t]: # check if labeling result is a list of labels if isinstance(label, list) and not self.multi_label: for sublabel in label: self.classes.add(sublabel) temp = {} - temp['file_upload'] = task.get('file_upload') - temp['data'] = task.get('data') - temp['label'] = sublabel - temp['result'] = res.get('value') - if annotation['ground_truth']: + temp["file_upload"] = task.get("file_upload") + temp["data"] = task.get("data") + temp["label"] = sublabel + temp["result"] = res.get("value") + if annotation["ground_truth"]: self.test_results.append(temp) - elif not annotation['ground_truth']: + elif not annotation["ground_truth"]: self.results.append(temp) else: if isinstance(label, list): @@ -605,17 +756,18 @@ def load_data(self, data: Optional[Any] = None, dataset: Optional[Any] = None) - else: self.classes.add(label) temp = {} - temp['file_upload'] = task.get('file_upload') - temp['data'] = task.get('data') - temp['label'] = label - temp['result'] = res.get('value') - if annotation['ground_truth']: + temp["file_upload"] = task.get("file_upload") + temp["data"] = task.get("data") + temp["label"] = label + temp["result"] = res.get("value") + if annotation["ground_truth"]: self.test_results.append(temp) - elif not annotation['ground_truth']: + elif not annotation["ground_truth"]: self.results.append(temp) self.num_classes = len(self.classes) # splitting result to train and val sets import random + random.shuffle(self.results) data_length = len(self.results) prop = data_length - int(data_length * self.split) @@ -623,28 +775,26 @@ def load_data(self, data: Optional[Any] = None, dataset: Optional[Any] = None) - self.results = self.results[prop:] def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] = None) -> Any: - """ - Load 1 sample from dataset - """ + """Load 1 sample from dataset.""" # all other data types input_data = deepcopy(sample) try: - del input_data['label'] + del input_data["label"] except KeyError: # no label in input data pass - result = {DefaultDataKeys.INPUT: input_data, - DefaultDataKeys.TARGET: self._get_labels_from_sample(sample['label'])} + result = { + DefaultDataKeys.INPUT: input_data, + DefaultDataKeys.TARGET: self._get_labels_from_sample(sample["label"]), + } return result def generate_dataset( - self, - data: Optional[DATA_TYPE], - running_stage: RunningStage, + self, + data: Optional[DATA_TYPE], + running_stage: RunningStage, ) -> Optional[Union[AutoDataset, IterableAutoDataset]]: - """ - Generate dataset from loaded data - """ + """Generate dataset from loaded data.""" if running_stage in (RunningStage.TRAINING, RunningStage.TUNING): self.load_data(data) dataset = self.results @@ -663,9 +813,7 @@ def generate_dataset( return dataset def _get_labels_from_sample(self, labels): - """ - Translate string labels to int - """ + """Translate string labels to int.""" sorted_labels = sorted(list(self.classes)) if isinstance(labels, list): label = [] @@ -682,26 +830,20 @@ def __init__(self): pass def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] = None) -> Any: - """ - Load 1 sample from dataset - """ - if sample['file_upload']: - p = os.path.join(self._data_folder, sample['file_upload']) + """Load 1 sample from dataset.""" + if sample["file_upload"]: + p = os.path.join(self._data_folder, sample["file_upload"]) else: - for key in sample.get('data'): - p = sample.get('data').get(key) + for key in sample.get("data"): + p = sample.get("data").get(key) # loading image image = default_loader(p) - result = {DefaultDataKeys.INPUT: image, - DefaultDataKeys.TARGET: self._get_labels_from_sample(sample['label'])} + result = {DefaultDataKeys.INPUT: image, DefaultDataKeys.TARGET: self._get_labels_from_sample(sample["label"])} return result class LabelStudioTextDataSource(LabelStudioDataSource): - def __init__(self, - backbone=None, - max_length=128 - ): + def __init__(self, backbone=None, max_length=128): super().__init__() if backbone: if _TEXT_AVAILABLE: @@ -711,33 +853,24 @@ def __init__(self, self.max_length = max_length def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] = None) -> Any: - """ - Load 1 sample from dataset - """ + """Load 1 sample from dataset.""" if self.backbone: data = "" - for key in sample.get('data'): - data += sample.get('data').get(key) - tokenized_data = self.tokenizer(data, - max_length=self.max_length, - truncation=True, - padding="max_length") + for key in sample.get("data"): + data += sample.get("data").get(key) + tokenized_data = self.tokenizer(data, max_length=self.max_length, truncation=True, padding="max_length") for key in tokenized_data: tokenized_data[key] = torch.tensor(tokenized_data[key]) - tokenized_data['labels'] = self._get_labels_from_sample(sample['label']) + tokenized_data["labels"] = self._get_labels_from_sample(sample["label"]) # separate text data type block result = tokenized_data return result class LabelStudioVideoDataSource(LabelStudioDataSource): - def __init__(self, - video_sampler=None, - clip_sampler=None, - clip_duration=1, - decode_audio=False, - decoder: str = "pyav" - ): + def __init__( + self, video_sampler=None, clip_sampler=None, clip_duration=1, decode_audio=False, decoder: str = "pyav" + ): super().__init__() self.video_sampler = video_sampler or torch.utils.data.RandomSampler self.clip_sampler = clip_sampler @@ -745,9 +878,7 @@ def __init__(self, self.decoder = decoder def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] = None) -> Any: - """ - Load 1 sample from dataset - """ + """Load 1 sample from dataset.""" return sample def load_data(self, data: Optional[Any] = None, dataset: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: @@ -759,13 +890,19 @@ def load_data(self, data: Optional[Any] = None, dataset: Optional[Any] = None) - def convert_to_encodedvideo(self, dataset): if len(dataset) > 0: from pytorchvideo.data import EncodedVideoDataset + dataset = EncodedVideoDataset( - [(os.path.join(self._data_folder, sample['file_upload']), - {"label": self._get_labels_from_sample(sample['label'])}) for sample in dataset], + [ + ( + os.path.join(self._data_folder, sample["file_upload"]), + {"label": self._get_labels_from_sample(sample["label"])}, + ) + for sample in dataset + ], self.clip_sampler, decode_audio=self.decode_audio, decoder=self.decoder, ) return dataset else: - return [] \ No newline at end of file + return [] diff --git a/flash/core/data/process.py b/flash/core/data/process.py index 2a94633821..c2ad49c390 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import os from abc import ABC, abstractclassmethod, abstractmethod from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence @@ -24,30 +25,27 @@ import flash from flash.core.data.batch import default_uncollate from flash.core.data.callback import FlashCallback -from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataSources +from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataKeys, DefaultDataSources from flash.core.data.properties import Properties +from flash.core.data.states import CollateFn from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext class BasePreprocess(ABC): - @abstractmethod def get_state_dict(self) -> Dict[str, Any]: - """ - Override this method to return state_dict - """ + """Override this method to return state_dict.""" @abstractclassmethod def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): - """ - Override this method to load from state_dict - """ + """Override this method to load from state_dict.""" class Preprocess(BasePreprocess, Properties): - """The :class:`~flash.core.data.process.Preprocess` encapsulates all the data processing logic that should run before - the data is passed to the model. It is particularly useful when you want to provide an end to end implementation - which works with 4 different stages: ``train``, ``validation``, ``test``, and inference (``predict``). + """The :class:`~flash.core.data.process.Preprocess` encapsulates all the data processing logic that should run + before the data is passed to the model. It is particularly useful when you want to provide an end to end + implementation which works with 4 different stages: ``train``, ``validation``, ``test``, and inference + (``predict``). The :class:`~flash.core.data.process.Preprocess` supports the following hooks: @@ -175,7 +173,6 @@ def pre_tensor_transform(self, sample: PIL.Image) -> PIL.Image: elif self.predicting: # logic for predicting - """ def __init__( @@ -184,8 +181,8 @@ def __init__( val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_sources: Optional[Dict[str, 'DataSource']] = None, - deserializer: Optional['Deserializer'] = None, + data_sources: Optional[Dict[str, "DataSource"]] = None, + deserializer: Optional["Deserializer"] = None, default_data_source: Optional[str] = None, ): super().__init__() @@ -213,8 +210,8 @@ def __init__( self._test_transform = convert_to_modules(self.test_transform) self._predict_transform = convert_to_modules(self.predict_transform) - if DefaultDataSources.DATASET not in data_sources: - data_sources[DefaultDataSources.DATASET] = DatasetDataSource() + if DefaultDataSources.DATASETS not in data_sources: + data_sources[DefaultDataSources.DATASETS] = DatasetDataSource() self._data_sources = data_sources self._deserializer = deserializer @@ -223,7 +220,7 @@ def __init__( self._default_collate: Callable = default_collate @property - def deserializer(self) -> Optional['Deserializer']: + def deserializer(self) -> Optional["Deserializer"]: return self._deserializer def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]: @@ -245,19 +242,19 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): preprocess_state_dict["_meta"]["module"] = self.__module__ preprocess_state_dict["_meta"]["class_name"] = self.__class__.__name__ preprocess_state_dict["_meta"]["_state"] = self._state - destination['preprocess.state_dict'] = preprocess_state_dict - self._ddp_params_and_buffers_to_ignore = ['preprocess.state_dict'] + destination["preprocess.state_dict"] = preprocess_state_dict + self._ddp_params_and_buffers_to_ignore = ["preprocess.state_dict"] return super()._save_to_state_dict(destination, prefix, keep_vars) - def _check_transforms(self, transform: Optional[Dict[str, Callable]], - stage: RunningStage) -> Optional[Dict[str, Callable]]: + def _check_transforms( + self, transform: Optional[Dict[str, Callable]], stage: RunningStage + ) -> Optional[Dict[str, Callable]]: if transform is None: return transform if not isinstance(transform, Dict): raise MisconfigurationException( - "Transform should be a dict. " - f"Here are the available keys for your transforms: {_PREPROCESS_FUNCS}." + "Transform should be a dict. " f"Here are the available keys for your transforms: {_PREPROCESS_FUNCS}." ) keys_diff = set(transform.keys()).difference(_PREPROCESS_FUNCS) @@ -272,8 +269,7 @@ def _check_transforms(self, transform: Optional[Dict[str, Callable]], if is_per_batch_transform_in and is_per_sample_transform_on_device_in: raise MisconfigurationException( - f'{transform}: `per_batch_transform` and `per_sample_transform_on_device` ' - f'are mutually exclusive.' + f"{transform}: `per_batch_transform` and `per_sample_transform_on_device` " f"are mutually exclusive." ) collate_in_worker: Optional[bool] = None @@ -310,7 +306,7 @@ def current_transform(self) -> Callable: @property def transforms(self) -> Dict[str, Optional[Dict[str, Callable]]]: - """ The transforms currently being used by this :class:`~flash.core.data.process.Preprocess`. """ + """The transforms currently being used by this :class:`~flash.core.data.process.Preprocess`.""" return { "train_transform": self.train_transform, "val_transform": self.val_transform, @@ -319,34 +315,37 @@ def transforms(self) -> Dict[str, Optional[Dict[str, Callable]]]: } @property - def callbacks(self) -> List['FlashCallback']: + def callbacks(self) -> List["FlashCallback"]: if not hasattr(self, "_callbacks"): self._callbacks: List[FlashCallback] = [] return self._callbacks @callbacks.setter - def callbacks(self, callbacks: List['FlashCallback']): + def callbacks(self, callbacks: List["FlashCallback"]): self._callbacks = callbacks - def add_callbacks(self, callbacks: List['FlashCallback']): + def add_callbacks(self, callbacks: List["FlashCallback"]): _callbacks = [c for c in callbacks if c not in self._callbacks] self._callbacks.extend(_callbacks) @staticmethod def default_transforms() -> Optional[Dict[str, Callable]]: - """ The default transforms to use. Will be overridden by transforms passed to the ``__init__``. """ + """The default transforms to use. + + Will be overridden by transforms passed to the ``__init__``. + """ return None def pre_tensor_transform(self, sample: Any) -> Any: - """ Transforms to apply on a single object. """ + """Transforms to apply on a single object.""" return self.current_transform(sample) def to_tensor_transform(self, sample: Any) -> Tensor: - """ Transforms to convert single object to a tensor. """ + """Transforms to convert single object to a tensor.""" return self.current_transform(sample) def post_tensor_transform(self, sample: Tensor) -> Tensor: - """ Transforms to apply on a tensor. """ + """Transforms to apply on a tensor.""" return self.current_transform(sample) def per_batch_transform(self, batch: Any) -> Any: @@ -359,12 +358,24 @@ def per_batch_transform(self, batch: Any) -> Any: """ return self.current_transform(batch) - def collate(self, samples: Sequence) -> Any: - """ Transform to convert a sequence of samples to a collated batch. """ + def collate(self, samples: Sequence, metadata=None) -> Any: + """Transform to convert a sequence of samples to a collated batch.""" current_transform = self.current_transform if current_transform is self._identity: - return self._default_collate(samples) - return self.current_transform(samples) + current_transform = self._default_collate + + # the model can provide a custom ``collate_fn``. + collate_fn = self.get_state(CollateFn) + if collate_fn is not None: + collate_fn = collate_fn.collate_fn + else: + collate_fn = current_transform + # return collate_fn.collate_fn(samples) + + parameters = inspect.signature(collate_fn).parameters + if len(parameters) > 1 and DefaultDataKeys.METADATA in parameters: + return collate_fn(samples, metadata) + return collate_fn(samples) def per_sample_transform_on_device(self, sample: Any) -> Any: """Transforms to apply to the data before the collation (per-sample basis). @@ -382,8 +393,7 @@ def per_sample_transform_on_device(self, sample: Any) -> Any: return self.current_transform(sample) def per_batch_transform_on_device(self, batch: Any) -> Any: - """ - Transforms to apply to a whole batch (if possible use this for efficiency). + """Transforms to apply to a whole batch (if possible use this for efficiency). .. note:: @@ -393,7 +403,8 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: return self.current_transform(batch) def available_data_sources(self) -> Sequence[str]: - """Get the list of available data source names for use with this :class:`~flash.core.data.process.Preprocess`. + """Get the list of available data source names for use with this + :class:`~flash.core.data.process.Preprocess`. Returns: The list of data source names. @@ -426,14 +437,13 @@ def data_source_of_name(self, data_source_name: str) -> DataSource: class DefaultPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - data_sources: Optional[Dict[str, 'DataSource']] = None, + data_sources: Optional[Dict[str, "DataSource"]] = None, default_data_source: Optional[str] = None, ): super().__init__( @@ -454,6 +464,8 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): class Postprocess(Properties): + """The :class:`~flash.core.data.process.Postprocess` encapsulates all the data processing logic that should run + after the model.""" def __init__(self, save_path: Optional[str] = None): super().__init__() @@ -463,6 +475,7 @@ def __init__(self, save_path: Optional[str] = None): @staticmethod def per_batch_transform(batch: Any) -> Any: """Transforms to apply on a whole batch before uncollation to individual samples. + Can involve both CPU and Device transforms as this is not applied in separate workers. """ return batch @@ -470,19 +483,22 @@ def per_batch_transform(batch: Any) -> Any: @staticmethod def per_sample_transform(sample: Any) -> Any: """Transforms to apply to a single sample after splitting up the batch. + Can involve both CPU and Device transforms as this is not applied in separate workers. """ return sample @staticmethod def uncollate(batch: Any) -> Any: - """Uncollates a batch into single samples. Tries to preserve the type whereever possible.""" + """Uncollates a batch into single samples. + + Tries to preserve the type whereever possible. + """ return default_uncollate(batch) @staticmethod def save_data(data: Any, path: str) -> None: - """Saves all data together to a single path. - """ + """Saves all data together to a single path.""" torch.save(data, path) @staticmethod @@ -492,7 +508,7 @@ def save_sample(sample: Any, path: str) -> None: # TODO: Are those needed ? def format_sample_save_path(self, path: str) -> str: - path = os.path.join(path, f'sample_{self._saved_samples}.ptl') + path = os.path.join(path, f"sample_{self._saved_samples}.ptl") self._saved_samples += 1 return path @@ -504,8 +520,8 @@ def _save_sample(self, sample: Any) -> None: class Serializer(Properties): - """A :class:`.Serializer` encapsulates a single ``serialize`` method which is used to convert the model output into - the desired output format when predicting.""" + """A :class:`.Serializer` encapsulates a single ``serialize`` method which is used to convert the model output + into the desired output format when predicting.""" def __init__(self): super().__init__() @@ -538,8 +554,8 @@ def __call__(self, sample: Any) -> Any: class SerializerMapping(Serializer): - """If the model output is a dictionary, then the :class:`.SerializerMapping` enables each entry in the dictionary - to be passed to it's own :class:`.Serializer`.""" + """If the model output is a dictionary, then the :class:`.SerializerMapping` enables each entry in the + dictionary to be passed to it's own :class:`.Serializer`.""" def __init__(self, serializers: Mapping[str, Serializer]): super().__init__() @@ -551,13 +567,13 @@ def serialize(self, sample: Any) -> Any: return {key: serializer.serialize(sample[key]) for key, serializer in self._serializers.items()} raise ValueError("The model output must be a mapping when using a SerializerMapping.") - def attach_data_pipeline_state(self, data_pipeline_state: 'flash.core.data.data_pipeline.DataPipelineState'): + def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"): for serializer in self._serializers.values(): serializer.attach_data_pipeline_state(data_pipeline_state) class Deserializer(Properties): - """""" + """Deserializer.""" def deserialize(self, sample: Any) -> Any: # TODO: Output must be a tensor??? raise NotImplementedError @@ -573,7 +589,7 @@ def __call__(self, sample: Any) -> Any: class DeserializerMapping(Deserializer): # TODO: This is essentially a duplicate of SerializerMapping, should be abstracted away somewhere - """""" + """Deserializer Mapping.""" def __init__(self, deserializers: Mapping[str, Deserializer]): super().__init__() @@ -585,6 +601,6 @@ def deserialize(self, sample: Any) -> Any: return {key: deserializer.deserialize(sample[key]) for key, deserializer in self._deserializers.items()} raise ValueError("The model output must be a mapping when using a DeserializerMapping.") - def attach_data_pipeline_state(self, data_pipeline_state: 'flash.core.data.data_pipeline.DataPipelineState'): + def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"): for deserializer in self._deserializers.values(): deserializer.attach_data_pipeline_state(data_pipeline_state) diff --git a/flash/core/data/properties.py b/flash/core/data/properties.py index 2d00ebf6c1..2a22846783 100644 --- a/flash/core/data/properties.py +++ b/flash/core/data/properties.py @@ -21,22 +21,19 @@ @dataclass(unsafe_hash=True, frozen=True) class ProcessState: - """ - Base class for all process states - """ + """Base class for all process states.""" -STATE_TYPE = TypeVar('STATE_TYPE', bound=ProcessState) +STATE_TYPE = TypeVar("STATE_TYPE", bound=ProcessState) class Properties: - def __init__(self): super().__init__() self._running_stage: Optional[RunningStage] = None self._current_fn: Optional[str] = None - self._data_pipeline_state: Optional['flash.core.data.data_pipeline.DataPipelineState'] = None + self._data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None self._state: Dict[Type[ProcessState], ProcessState] = {} def get_state(self, state_type: Type[STATE_TYPE]) -> Optional[STATE_TYPE]: @@ -51,7 +48,7 @@ def set_state(self, state: ProcessState): if self._data_pipeline_state is not None: self._data_pipeline_state.set_state(state) - def attach_data_pipeline_state(self, data_pipeline_state: 'flash.core.data.data_pipeline.DataPipelineState'): + def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"): self._data_pipeline_state = data_pipeline_state for state in self._state.values(): self._data_pipeline_state.set_state(state) diff --git a/flash/core/data/splits.py b/flash/core/data/splits.py index 45b833c852..5102b2a224 100644 --- a/flash/core/data/splits.py +++ b/flash/core/data/splits.py @@ -6,8 +6,7 @@ class SplitDataset(Dataset): - """ - SplitDataset is used to create Dataset Subset using indices. + """SplitDataset is used to create Dataset Subset using indices. Args: @@ -20,7 +19,6 @@ class SplitDataset(Dataset): split_ds = SplitDataset(dataset, indices=[10, 14, 25]) split_ds = SplitDataset(dataset, indices=[10, 10, 10, 14, 25], use_duplicated_indices=True) - """ _INTERNAL_KEYS = ("dataset", "indices", "data") diff --git a/flash/core/data/states.py b/flash/core/data/states.py new file mode 100644 index 0000000000..de026f7d73 --- /dev/null +++ b/flash/core/data/states.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass +from typing import Callable, Optional + +from flash.core.data.properties import ProcessState + + +@dataclass(unsafe_hash=True, frozen=True) +class PreTensorTransform(ProcessState): + + transform: Optional[Callable] = None + + +@dataclass(unsafe_hash=True, frozen=True) +class ToTensorTransform(ProcessState): + + transform: Optional[Callable] = None + + +@dataclass(unsafe_hash=True, frozen=True) +class PostTensorTransform(ProcessState): + + transform: Optional[Callable] = None + + +@dataclass(unsafe_hash=True, frozen=True) +class CollateFn(ProcessState): + + collate_fn: Optional[Callable] = None diff --git a/flash/core/data/transforms.py b/flash/core/data/transforms.py index c07928ff7d..759c1bbc1e 100644 --- a/flash/core/data/transforms.py +++ b/flash/core/data/transforms.py @@ -21,9 +21,9 @@ class ApplyToKeys(nn.Sequential): - """The ``ApplyToKeys`` class is an ``nn.Sequential`` which applies the given transforms to the given keys from the - input. When a single key is given, a single value will be passed to the transforms. When multiple keys are given, - the corresponding values will be passed to the transforms as a list. + """The ``ApplyToKeys`` class is an ``nn.Sequential`` which applies the given transforms to the given keys from + the input. When a single key is given, a single value will be passed to the transforms. When multiple keys are + given, the corresponding values will be passed to the transforms as a list. Args: keys: The key (``str``) or sequence of keys (``Sequence[str]``) to extract and forward to the transforms. @@ -31,7 +31,7 @@ class ApplyToKeys(nn.Sequential): """ def __init__(self, keys: Union[str, Sequence[str]], *args): - super().__init__(*[convert_to_modules(arg) for arg in args]) + super().__init__(*(convert_to_modules(arg) for arg in args)) if isinstance(keys, str): keys = [keys] self.keys = keys @@ -44,7 +44,7 @@ def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]: inputs = inputs[0] outputs = super().forward(inputs) if not isinstance(outputs, Sequence): - outputs = (outputs, ) + outputs = (outputs,) result = {} result.update(x) @@ -72,7 +72,7 @@ class KorniaParallelTransforms(nn.Sequential): """ def __init__(self, *args): - super().__init__(*[convert_to_modules(arg) for arg in args]) + super().__init__(*(convert_to_modules(arg) for arg in args)) def forward(self, inputs: Any): result = list(inputs) if isinstance(inputs, Sequence) else [inputs] @@ -99,11 +99,14 @@ def forward(self, inputs: Any): def kornia_collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Any]: - """Kornia transforms add batch dimension which need to be removed. This function removes that dimension and then - applies ``torch.utils.data._utils.collate.default_collate``.""" + """Kornia transforms add batch dimension which need to be removed. + + This function removes that dimension and then + applies ``torch.utils.data._utils.collate.default_collate``. + """ for sample in samples: for key in sample.keys(): - if torch.is_tensor(sample[key]): + if torch.is_tensor(sample[key]) and sample[key].ndim == 4: sample[key] = sample[key].squeeze(0) return default_collate(samples) @@ -112,8 +115,8 @@ def merge_transforms( base_transforms: Dict[str, Callable], additional_transforms: Dict[str, Callable], ) -> Dict[str, Callable]: - """Utility function to merge two transform dictionaries. For each hook, the ``additional_transforms`` will be be - called after the ``base_transforms``. + """Utility function to merge two transform dictionaries. For each hook, the ``additional_transforms`` will be + be called after the ``base_transforms``. Args: base_transforms: The base transforms dictionary. diff --git a/flash/core/data/utils.py b/flash/core/data/utils.py index 63f28301d6..3779b7426e 100644 --- a/flash/core/data/utils.py +++ b/flash/core/data/utils.py @@ -24,10 +24,10 @@ from tqdm.auto import tqdm as tq _STAGES_PREFIX = { - RunningStage.TRAINING: 'train', - RunningStage.TESTING: 'test', - RunningStage.VALIDATING: 'val', - RunningStage.PREDICTING: 'predict' + RunningStage.TRAINING: "train", + RunningStage.TESTING: "test", + RunningStage.VALIDATING: "val", + RunningStage.PREDICTING: "predict", } _STAGES_PREFIX_VALUES = {"train", "test", "val", "predict"} @@ -61,7 +61,6 @@ class CurrentRunningStageContext: - def __init__(self, running_stage: RunningStage, obj: Any, reset: bool = True): self._running_stage = running_stage self._obj = obj @@ -79,7 +78,6 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: class CurrentFuncContext: - def __init__(self, current_fn: str, obj: Any): self._current_fn = current_fn self._obj = obj @@ -96,7 +94,6 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: class CurrentRunningStageFuncContext: - def __init__(self, running_stage: RunningStage, current_fn: str, obj: Any): self._running_stage = running_stage self._current_fn = current_fn @@ -117,8 +114,7 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: - """ - Download file with progressbar + """Download file with progressbar. # Code taken from: https://gist.github.com/ruxi/5d6803c116ec1130d484a4ab8c00c603 # __author__ = "github.com/ruxi" @@ -132,9 +128,9 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: if not os.path.exists(path): os.makedirs(path) - local_filename = os.path.join(path, url.split('/')[-1]) + local_filename = os.path.join(path, url.split("/")[-1]) r = requests.get(url, stream=True, verify=False) - file_size = int(r.headers['Content-Length']) if 'Content-Length' in r.headers else 0 + file_size = int(r.headers["Content-Length"]) if "Content-Length" in r.headers else 0 chunk_size = 1024 num_bars = int(file_size / chunk_size) if verbose: @@ -142,19 +138,19 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: print(dict(num_bars=num_bars)) if not os.path.exists(local_filename): - with open(local_filename, 'wb') as fp: + with open(local_filename, "wb") as fp: for chunk in tq( r.iter_content(chunk_size=chunk_size), total=num_bars, - unit='KB', + unit="KB", desc=local_filename, - leave=True # progressbar stays + leave=True, # progressbar stays ): fp.write(chunk) # type: ignore - if '.zip' in local_filename: + if ".zip" in local_filename: if os.path.exists(local_filename): - with zipfile.ZipFile(local_filename, 'r') as zip_ref: + with zipfile.ZipFile(local_filename, "r") as zip_ref: zip_ref.extractall(path) @@ -172,10 +168,7 @@ def _contains_any_tensor(value: Any, dtype: Type = Tensor) -> bool: class FuncModule(torch.nn.Module): - """ - This class is used to wrap a callable within a nn.Module and - apply the wrapped function in `__call__` - """ + """This class is used to wrap a callable within a nn.Module and apply the wrapped function in `__call__`""" def __init__(self, func: Callable) -> None: super().__init__() diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 63eb209a00..854164fb15 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -21,7 +21,6 @@ class NoFreeze(BaseFinetuning): - def freeze_before_training(self, pl_module: LightningModule) -> None: pass @@ -36,19 +35,16 @@ def finetune_function( class FlashBaseFinetuning(BaseFinetuning): + """FlashBaseFinetuning can be used to create a custom Flash Finetuning Callback. - def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): - r""" - - FlashBaseFinetuning can be used to create a custom Flash Finetuning Callback. - - Override ``finetune_function`` to put your unfreeze logic. + Override :meth:`.finetune_function` to put your unfreeze logic. + """ + def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True): + """ Args: attr_names: Name(s) of the module attributes of the model to be frozen. - train_bn: Whether to train Batch Norm layer - """ super().__init__() @@ -70,7 +66,6 @@ def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: O class Freeze(FlashBaseFinetuning): - def finetune_function( self, pl_module: LightningModule, @@ -82,7 +77,6 @@ def finetune_function( class FreezeUnfreeze(FlashBaseFinetuning): - def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_epoch: int = 10): super().__init__(attr_names, train_bn) self.unfreeze_epoch = unfreeze_epoch @@ -105,13 +99,12 @@ def finetune_function( class UnfreezeMilestones(FlashBaseFinetuning): - def __init__( self, attr_names: Union[str, List[str]] = "backbone", train_bn: bool = True, unfreeze_milestones: tuple = (5, 10), - num_layers: int = 5 + num_layers: int = 5, ): self.unfreeze_milestones = unfreeze_milestones self.num_layers = num_layers @@ -129,7 +122,7 @@ def finetune_function( if epoch == self.unfreeze_milestones[0]: # unfreeze num_layers last layers self.unfreeze_and_add_param_group( - modules=backbone_modules[-self.num_layers:], + modules=backbone_modules[-self.num_layers :], optimizer=optimizer, train_bn=self.train_bn, ) @@ -137,7 +130,7 @@ def finetune_function( elif epoch == self.unfreeze_milestones[1]: # unfreeze remaining layers self.unfreeze_and_add_param_group( - modules=backbone_modules[:-self.num_layers], + modules=backbone_modules[: -self.num_layers], optimizer=optimizer, train_bn=self.train_bn, ) @@ -147,7 +140,7 @@ def finetune_function( "no_freeze": NoFreeze, "freeze": Freeze, "freeze_unfreeze": FreezeUnfreeze, - "unfreeze_milestones": UnfreezeMilestones + "unfreeze_milestones": UnfreezeMilestones, } diff --git a/flash/core/integrations/fiftyone/utils.py b/flash/core/integrations/fiftyone/utils.py index 3c9bbb6d44..d5c8ae3fb3 100644 --- a/flash/core/integrations/fiftyone/utils.py +++ b/flash/core/integrations/fiftyone/utils.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional, TYPE_CHECKING, Union import flash -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires Label, Session = None, None if _FIFTYONE_AVAILABLE: @@ -13,6 +13,7 @@ fo = None +@requires("fiftyone") def visualize( predictions: Union[List[Label], List[Dict[str, Label]]], filepaths: Optional[List[str]] = None, @@ -56,8 +57,6 @@ def visualize( Returns: a :class:`fiftyone:fiftyone.core.session.Session` """ - if not _FIFTYONE_AVAILABLE: - raise ModuleNotFoundError("Please, `pip install fiftyone`.") if flash._IS_TESTING: return None diff --git a/flash/core/integrations/icevision/__init__.py b/flash/core/integrations/icevision/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py new file mode 100644 index 0000000000..af95da9a52 --- /dev/null +++ b/flash/core/integrations/icevision/adapter.py @@ -0,0 +1,202 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +from typing import Any, Callable, Dict, List, Optional + +from torch.utils.data import DataLoader, Sampler + +from flash.core.adapter import Adapter +from flash.core.data.auto_dataset import BaseAutoDataset +from flash.core.data.data_source import DefaultDataKeys +from flash.core.integrations.icevision.transforms import to_icevision_record +from flash.core.model import Task +from flash.core.utilities.imports import _ICEVISION_AVAILABLE +from flash.core.utilities.url_error import catch_url_error + +if _ICEVISION_AVAILABLE: + from icevision.metrics import COCOMetric + from icevision.metrics import Metric as IceVisionMetric +else: + COCOMetric = object + + +class SimpleCOCOMetric(COCOMetric): + def finalize(self) -> Dict[str, float]: + logs = super().finalize() + return { + "Precision (IoU=0.50:0.95,area=all)": logs["AP (IoU=0.50:0.95) area=all"], + "Recall (IoU=0.50:0.95,area=all,maxDets=100)": logs["AR (IoU=0.50:0.95) area=all maxDets=100"], + } + + +class IceVisionAdapter(Adapter): + """The ``IceVisionAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with IceVision.""" + + required_extras: str = "image" + + def __init__(self, model_type, model, icevision_adapter, backbone): + super().__init__() + + self.model_type = model_type + self.model = model + self.icevision_adapter = icevision_adapter + self.backbone = backbone + + @classmethod + @catch_url_error + def from_task( + cls, + task: Task, + num_classes: int, + backbone: str, + head: str, + pretrained: bool = True, + metrics: Optional["IceVisionMetric"] = None, + image_size: Optional = None, + **kwargs, + ) -> Adapter: + metadata = task.heads.get(head, with_metadata=True) + backbones = metadata["metadata"]["backbones"] + backbone_config = backbones.get(backbone)(pretrained) + model_type, model, icevision_adapter, backbone = metadata["fn"]( + backbone_config, + num_classes, + image_size=image_size, + **kwargs, + ) + icevision_adapter = icevision_adapter(model=model, metrics=metrics) + return cls(model_type, model, icevision_adapter, backbone) + + @staticmethod + def _collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] = None): + metadata = metadata or [None] * len(samples) + return collate_fn( + [to_icevision_record({**sample, DefaultDataKeys.METADATA: m}) for sample, m in zip(samples, metadata)] + ) + + def process_train_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Optional[Callable] = None, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + data_loader = self.model_type.train_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + return data_loader + + def process_val_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Optional[Callable] = None, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + data_loader = self.model_type.valid_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + return data_loader + + def process_test_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Optional[Callable] = None, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + data_loader = self.model_type.valid_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + return data_loader + + def process_predict_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int = 1, + num_workers: int = 0, + pin_memory: bool = False, + collate_fn: Callable = lambda x: x, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + data_loader = self.model_type.infer_dl( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + return data_loader + + def training_step(self, batch, batch_idx) -> Any: + return self.icevision_adapter.training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + return self.icevision_adapter.validation_step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + return self.icevision_adapter.validation_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return self(batch) + + def forward(self, batch: Any) -> Any: + return self.model_type.predict_from_dl(self.model, [batch], show_pbar=False) + + def training_epoch_end(self, outputs) -> None: + return self.icevision_adapter.training_epoch_end(outputs) + + def validation_epoch_end(self, outputs) -> None: + return self.icevision_adapter.validation_epoch_end(outputs) + + def test_epoch_end(self, outputs) -> None: + return self.icevision_adapter.validation_epoch_end(outputs) diff --git a/flash/core/integrations/icevision/backbones.py b/flash/core/integrations/icevision/backbones.py new file mode 100644 index 0000000000..dd30d3be56 --- /dev/null +++ b/flash/core/integrations/icevision/backbones.py @@ -0,0 +1,63 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from inspect import getmembers + +from torch import nn + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _ICEVISION_AVAILABLE + +if _ICEVISION_AVAILABLE: + from icevision.backbones import BackboneConfig + + +def icevision_model_adapter(model_type): + class IceVisionModelAdapter(model_type.lightning.ModelAdapter): + def log(self, name, value, **kwargs): + if "prog_bar" not in kwargs: + kwargs["prog_bar"] = True + return super().log(name.split("/")[-1], value, **kwargs) + + return IceVisionModelAdapter + + +def load_icevision(adapter, model_type, backbone, num_classes, **kwargs): + model = model_type.model(backbone=backbone, num_classes=num_classes, **kwargs) + + backbone = nn.Module() + params = model.param_groups()[0] + for i, param in enumerate(params): + backbone.register_parameter(f"backbone_{i}", param) + + return model_type, model, adapter(model_type), backbone + + +def load_icevision_ignore_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs): + return load_icevision(adapter, model_type, backbone, num_classes, **kwargs) + + +def load_icevision_with_image_size(adapter, model_type, backbone, num_classes, image_size=None, **kwargs): + kwargs["img_size"] = image_size + return load_icevision(adapter, model_type, backbone, num_classes, **kwargs) + + +def get_backbones(model_type): + _BACKBONES = FlashRegistry("backbones") + + for backbone_name, backbone_config in getmembers(model_type.backbones, lambda x: isinstance(x, BackboneConfig)): + _BACKBONES( + backbone_config, + name=backbone_name, + ) + return _BACKBONES diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py new file mode 100644 index 0000000000..80ce622616 --- /dev/null +++ b/flash/core/integrations/icevision/data.py @@ -0,0 +1,79 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type + +import numpy as np + +from flash.core.data.data_source import DefaultDataKeys +from flash.core.integrations.icevision.transforms import from_icevision_record +from flash.core.utilities.imports import _ICEVISION_AVAILABLE +from flash.image.data import ImagePathsDataSource + +if _ICEVISION_AVAILABLE: + from icevision.core.record import BaseRecord + from icevision.core.record_components import ClassMapRecordComponent, ImageRecordComponent, tasks + from icevision.data.data_splitter import SingleSplitSplitter + from icevision.parsers.parser import Parser + + +class IceVisionPathsDataSource(ImagePathsDataSource): + def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + return super().predict_load_data(data, dataset) + + def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + record = sample[DefaultDataKeys.INPUT].load() + return from_icevision_record(record) + + def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + sample = super().load_sample(sample) + image = np.array(sample[DefaultDataKeys.INPUT]) + record = BaseRecord([ImageRecordComponent()]) + + record.set_img(image) + record.add_component(ClassMapRecordComponent(task=tasks.detection)) + return from_icevision_record(record) + + +class IceVisionParserDataSource(IceVisionPathsDataSource): + def __init__(self, parser: Optional[Type["Parser"]] = None): + super().__init__() + self.parser = parser + + def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + root, ann_file = data + + if self.parser is not None: + parser = self.parser(ann_file, root) + dataset.num_classes = len(parser.class_map) + records = parser.parse(data_splitter=SingleSplitSplitter()) + return [{DefaultDataKeys.INPUT: record} for record in records[0]] + else: + raise ValueError("The parser type must be provided") + + +class IceDataParserDataSource(IceVisionPathsDataSource): + def __init__(self, parser: Optional[Callable] = None): + super().__init__() + self.parser = parser + + def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + root = data + + if self.parser is not None: + parser = self.parser(root) + dataset.num_classes = len(parser.class_map) + records = parser.parse(data_splitter=SingleSplitSplitter()) + return [{DefaultDataKeys.INPUT: record} for record in records[0]] + else: + raise ValueError("The parser must be provided") diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py new file mode 100644 index 0000000000..3d347c730c --- /dev/null +++ b/flash/core/integrations/icevision/transforms.py @@ -0,0 +1,198 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, Tuple + +from torch import nn + +from flash.core.data.data_source import DefaultDataKeys +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires_extras + +if _ICEVISION_AVAILABLE: + from icevision.core import tasks + from icevision.core.bbox import BBox + from icevision.core.keypoints import KeyPoints + from icevision.core.mask import EncodedRLEs, MaskArray + from icevision.core.record import BaseRecord + from icevision.core.record_components import ( + BBoxesRecordComponent, + ClassMapRecordComponent, + FilepathRecordComponent, + ImageRecordComponent, + InstancesLabelsRecordComponent, + KeyPointsRecordComponent, + MasksRecordComponent, + RecordIDRecordComponent, + ) + from icevision.tfms import A + + +def to_icevision_record(sample: Dict[str, Any]): + record = BaseRecord([]) + + metadata = sample.get(DefaultDataKeys.METADATA, None) or {} + + if "image_id" in metadata: + record_id_component = RecordIDRecordComponent() + record_id_component.set_record_id(metadata["image_id"]) + + component = ClassMapRecordComponent(tasks.detection) + component.set_class_map(metadata.get("class_map", None)) + record.add_component(component) + + if "labels" in sample[DefaultDataKeys.TARGET]: + labels_component = InstancesLabelsRecordComponent() + labels_component.add_labels_by_id(sample[DefaultDataKeys.TARGET]["labels"]) + record.add_component(labels_component) + + if "bboxes" in sample[DefaultDataKeys.TARGET]: + bboxes = [ + BBox.from_xywh(bbox["xmin"], bbox["ymin"], bbox["width"], bbox["height"]) + for bbox in sample[DefaultDataKeys.TARGET]["bboxes"] + ] + component = BBoxesRecordComponent() + component.set_bboxes(bboxes) + record.add_component(component) + + if "masks" in sample[DefaultDataKeys.TARGET]: + mask_array = MaskArray(sample[DefaultDataKeys.TARGET]["masks"]) + component = MasksRecordComponent() + component.set_masks(mask_array) + record.add_component(component) + + if "keypoints" in sample[DefaultDataKeys.TARGET]: + keypoints = [] + + for keypoints_list, keypoints_metadata in zip( + sample[DefaultDataKeys.TARGET]["keypoints"], sample[DefaultDataKeys.TARGET]["keypoints_metadata"] + ): + xyv = [] + for keypoint in keypoints_list: + xyv.extend((keypoint["x"], keypoint["y"], keypoint["visible"])) + + keypoints.append(KeyPoints.from_xyv(xyv, keypoints_metadata)) + component = KeyPointsRecordComponent() + component.set_keypoints(keypoints) + record.add_component(component) + + if isinstance(sample[DefaultDataKeys.INPUT], str): + input_component = FilepathRecordComponent() + input_component.set_filepath(sample[DefaultDataKeys.INPUT]) + else: + if "filepath" in metadata: + input_component = FilepathRecordComponent() + input_component.filepath = metadata["filepath"] + else: + input_component = ImageRecordComponent() + input_component.composite = record + input_component.set_img(sample[DefaultDataKeys.INPUT]) + record.add_component(input_component) + + return record + + +def from_icevision_record(record: "BaseRecord"): + sample = { + DefaultDataKeys.METADATA: { + "image_id": record.record_id, + } + } + + if record.img is not None: + sample[DefaultDataKeys.INPUT] = record.img + filepath = getattr(record, "filepath", None) + if filepath is not None: + sample[DefaultDataKeys.METADATA]["filepath"] = filepath + elif record.filepath is not None: + sample[DefaultDataKeys.INPUT] = record.filepath + + sample[DefaultDataKeys.TARGET] = {} + + if hasattr(record.detection, "bboxes"): + sample[DefaultDataKeys.TARGET]["bboxes"] = [] + for bbox in record.detection.bboxes: + bbox_list = list(bbox.xywh) + bbox_dict = { + "xmin": bbox_list[0], + "ymin": bbox_list[1], + "width": bbox_list[2], + "height": bbox_list[3], + } + sample[DefaultDataKeys.TARGET]["bboxes"].append(bbox_dict) + + if hasattr(record.detection, "masks"): + masks = record.detection.masks + + if isinstance(masks, EncodedRLEs): + masks = masks.to_mask(record.height, record.width) + + if isinstance(masks, MaskArray): + sample[DefaultDataKeys.TARGET]["masks"] = masks.data + else: + raise RuntimeError("Masks are expected to be a MaskArray or EncodedRLEs.") + + if hasattr(record.detection, "keypoints"): + keypoints = record.detection.keypoints + + sample[DefaultDataKeys.TARGET]["keypoints"] = [] + sample[DefaultDataKeys.TARGET]["keypoints_metadata"] = [] + + for keypoint in keypoints: + keypoints_list = [] + for x, y, v in keypoint.xyv: + keypoints_list.append( + { + "x": x, + "y": y, + "visible": v, + } + ) + sample[DefaultDataKeys.TARGET]["keypoints"].append(keypoints_list) + + # TODO: Unpack keypoints_metadata + sample[DefaultDataKeys.TARGET]["keypoints_metadata"].append(keypoint.metadata) + + if getattr(record.detection, "label_ids", None) is not None: + sample[DefaultDataKeys.TARGET]["labels"] = list(record.detection.label_ids) + + if getattr(record.detection, "class_map", None) is not None: + sample[DefaultDataKeys.METADATA]["class_map"] = record.detection.class_map + + return sample + + +class IceVisionTransformAdapter(nn.Module): + def __init__(self, transform): + super().__init__() + self.transform = A.Adapter(transform) + + def forward(self, x): + record = to_icevision_record(x) + record = self.transform(record) + return from_icevision_record(record) + + +@requires_extras("image") +def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: + """The default transforms from IceVision.""" + return { + "pre_tensor_transform": IceVisionTransformAdapter([*A.resize_and_pad(image_size), A.Normalize()]), + } + + +@requires_extras("image") +def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: + """The default augmentations from IceVision.""" + return { + "pre_tensor_transform": IceVisionTransformAdapter([*A.aug_tfms(size=image_size), A.Normalize()]), + } diff --git a/flash/core/model.py b/flash/core/model.py index 2c4c2b6ada..7e4d62441b 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -13,6 +13,7 @@ # limitations under the License. import functools import inspect +import pickle from abc import ABCMeta from copy import deepcopy from importlib import import_module @@ -21,15 +22,18 @@ import pytorch_lightning as pl import torch import torchmetrics -from pytorch_lightning import LightningModule +from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.optim.lr_scheduler import _LRScheduler from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader, Sampler import flash +from flash.core.data.auto_dataset import BaseAutoDataset from flash.core.data.data_pipeline import DataPipeline, DataPipelineState from flash.core.data.data_source import DataSource from flash.core.data.process import ( @@ -40,19 +44,186 @@ Serializer, SerializerMapping, ) +from flash.core.data.properties import ProcessState from flash.core.registry import FlashRegistry from flash.core.schedulers import _SCHEDULERS_REGISTRY from flash.core.serve import Composition from flash.core.utilities.apply_func import get_callable_dict -from flash.core.utilities.imports import _requires_extras +from flash.core.utilities.imports import requires_extras -class BenchmarkConvergenceCI(Callback): +class ModuleWrapperBase: + """The ``ModuleWrapperBase`` is a base for classes which wrap a ``LightningModule`` or an instance of + ``ModuleWrapperBase``. + + This class ensures that trainer attributes are forwarded to any wrapped or nested + ``LightningModule`` instances so that nested calls to ``.log`` are handled correctly. The ``ModuleWrapperBase`` is + also stateful, meaning that a :class:`~flash.core.data.data_pipeline.DataPipelineState` can be attached. Attached + state will be forwarded to any nested ``ModuleWrapperBase`` instances. + """ + def __init__(self): + super().__init__() + + self._children = [] + + # TODO: create enum values to define what are the exact states + self._data_pipeline_state: Optional[DataPipelineState] = None + + # model own internal state shared with the data pipeline. + self._state: Dict[Type[ProcessState], ProcessState] = {} + + def __setattr__(self, key, value): + if isinstance(value, (LightningModule, ModuleWrapperBase)): + self._children.append(key) + patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results", "_data_pipeline_state"] + if isinstance(value, Trainer) or key in patched_attributes: + if hasattr(self, "_children"): + for child in self._children: + setattr(getattr(self, child), key, value) + super().__setattr__(key, value) + + def get_state(self, state_type): + if state_type in self._state: + return self._state[state_type] + if self._data_pipeline_state is not None: + return self._data_pipeline_state.get_state(state_type) + return None + + def set_state(self, state: ProcessState): + self._state[type(state)] = state + if self._data_pipeline_state is not None: + self._data_pipeline_state.set_state(state) + + def attach_data_pipeline_state(self, data_pipeline_state: "DataPipelineState"): + for state in self._state.values(): + data_pipeline_state.set_state(state) + for child in self._children: + child = getattr(self, child) + if hasattr(child, "attach_data_pipeline_state"): + child.attach_data_pipeline_state(data_pipeline_state) + + +class DatasetProcessor: + """The ``DatasetProcessor`` mixin provides hooks for classes which need custom logic for producing the data + loaders for each running stage given the corresponding dataset.""" + + def _process_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + collate_fn=collate_fn, + ) + + def process_train_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def process_val_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def process_test_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def process_predict_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int = 1, + num_workers: int = 0, + pin_memory: bool = False, + collate_fn: Callable = None, + shuffle: bool = False, + drop_last: bool = False, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + +class BenchmarkConvergenceCI(Callback): def __init__(self): self.history = [] - def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.history.append(deepcopy(trainer.callback_metrics)) if trainer.current_epoch == trainer.max_epochs - 1: fn = getattr(pl_module, "_ci_benchmark_fn", None) @@ -63,10 +234,8 @@ def on_validation_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModul def predict_context(func: Callable) -> Callable: - """ - This decorator is used as context manager - to put model in eval mode before running predict and reset to train after. - """ + """This decorator is used as context manager to put model in eval mode before running predict and reset to + train after.""" @functools.wraps(func) def wrapper(self, *args, **kwargs) -> Any: @@ -86,20 +255,19 @@ def wrapper(self, *args, **kwargs) -> Any: class CheckDependenciesMeta(ABCMeta): - def __new__(mcs, *args, **kwargs): result = ABCMeta.__new__(mcs, *args, **kwargs) if result.required_extras is not None: - result.__init__ = _requires_extras(result.required_extras)(result.__init__) + result.__init__ = requires_extras(result.required_extras)(result.__init__) load_from_checkpoint = getattr(result, "load_from_checkpoint", None) if load_from_checkpoint is not None: result.load_from_checkpoint = classmethod( - _requires_extras(result.required_extras)(result.load_from_checkpoint.__func__) + requires_extras(result.required_extras)(result.load_from_checkpoint.__func__) ) return result -class Task(LightningModule, metaclass=CheckDependenciesMeta): +class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, metaclass=CheckDependenciesMeta): """A general Task. Args: @@ -140,7 +308,8 @@ def __init__( self.optimizer_kwargs = optimizer_kwargs or {} self.scheduler_kwargs = scheduler_kwargs or {} - self.metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics)) + self.train_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(metrics)) + self.val_metrics = nn.ModuleDict({} if metrics is None else get_callable_dict(deepcopy(metrics))) self.learning_rate = learning_rate # TODO: should we save more? Bug on some regarding yaml if we save metrics self.save_hyperparameters("learning_rate", "optimizer") @@ -150,39 +319,54 @@ def __init__( self._postprocess: Optional[Postprocess] = postprocess self._serializer: Optional[Serializer] = None - # TODO: create enum values to define what are the exact states - self._data_pipeline_state: Optional[DataPipelineState] = None - # Explicitly set the serializer to call the setter self.deserializer = deserializer self.serializer = serializer - def step(self, batch: Any, batch_idx: int) -> Any: - """ - The training/validation/test step. Override for custom behavior. + def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: + """The training/validation/test step. + + Override for custom behavior. """ x, y = batch y_hat = self(x) + y, y_hat = self.apply_filtering(y, y_hat) output = {"y_hat": y_hat} y_hat = self.to_loss_format(output["y_hat"]) losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()} logs = {} y_hat = self.to_metrics_format(output["y_hat"]) - for name, metric in self.metrics.items(): + + logs = {} + + for name, metric in metrics.items(): if isinstance(metric, torchmetrics.metric.Metric): metric(y_hat, y) logs[name] = metric # log the metric itself if it is of type Metric else: logs[name] = metric(y_hat, y) - logs.update(losses) + if len(losses.values()) > 1: logs["total_loss"] = sum(losses.values()) return logs["total_loss"], logs - output["loss"] = list(losses.values())[0] - output["logs"] = logs + + output["loss"] = self.compute_loss(losses) + output["logs"] = self.compute_logs(logs, losses) output["y"] = y return output + def compute_loss(self, losses: Dict[str, torch.Tensor]) -> torch.Tensor: + return list(losses.values())[0] + + def compute_logs(self, logs: Dict[str, Any], losses: Dict[str, torch.Tensor]): + logs.update(losses) + return logs + + @staticmethod + def apply_filtering(y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """This function is used to filter some labels or predictions which aren't conform.""" + return y, y_hat + @staticmethod def to_loss_format(x: torch.Tensor) -> torch.Tensor: return x @@ -195,16 +379,16 @@ def forward(self, x: Any) -> Any: return self.model(x) def training_step(self, batch: Any, batch_idx: int) -> Any: - output = self.step(batch, batch_idx) + output = self.step(batch, batch_idx, self.train_metrics) self.log_dict({f"train_{k}": v for k, v in output["logs"].items()}, on_step=True, on_epoch=True, prog_bar=True) return output["loss"] def validation_step(self, batch: Any, batch_idx: int) -> None: - output = self.step(batch, batch_idx) + output = self.step(batch, batch_idx, self.val_metrics) self.log_dict({f"val_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True) def test_step(self, batch: Any, batch_idx: int) -> None: - output = self.step(batch, batch_idx) + output = self.step(batch, batch_idx, self.val_metrics) self.log_dict({f"test_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True) @predict_context @@ -215,8 +399,7 @@ def predict( deserializer: Optional[Deserializer] = None, data_pipeline: Optional[DataPipeline] = None, ) -> Any: - """ - Predict function for raw data or processed data + """Predict function for raw data or processed data. Args: x: Input to predict. Can be raw data or processed data. If str, assumed to be a folder of data. @@ -229,8 +412,10 @@ def predict( running_stage = RunningStage.PREDICTING data_pipeline = self.build_data_pipeline(data_source or "default", deserializer, data_pipeline) - x = list(data_pipeline.data_source.generate_dataset(x, running_stage)) - x = data_pipeline.worker_preprocessor(running_stage)(x) + dataset = data_pipeline.data_source.generate_dataset(x, running_stage) + dataloader = self.process_predict_dataset(dataset) + x = list(dataloader.dataset) + x = data_pipeline.worker_preprocessor(running_stage, collate_fn=dataloader.collate_fn)(x) # todo (tchaton): Remove this when sync with Lightning master. if len(inspect.signature(self.transfer_batch_to_device).parameters) == 3: x = self.transfer_batch_to_device(x, self.device, 0) @@ -322,8 +507,11 @@ def deserializer(self, deserializer: Union[Deserializer, Mapping[str, Deserializ @torch.jit.unused @property def serializer(self) -> Optional[Serializer]: - """The current :class:`.Serializer` associated with this model. If this property was set to a mapping - (e.g. ``.serializer = {'output1': SerializerOne()}``) then this will be a :class:`.MappingSerializer`.""" + """The current :class:`.Serializer` associated with this model. + + If this property was set to a mapping + (e.g. ``.serializer = {'output1': SerializerOne()}``) then this will be a :class:`.MappingSerializer`. + """ return self._serializer @torch.jit.unused @@ -358,21 +546,23 @@ def build_data_pipeline( deserializer, old_data_source, preprocess, postprocess, serializer = None, None, None, None, None # Datamodule - if self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None: - old_data_source = getattr(self.datamodule.data_pipeline, 'data_source', None) - preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None) - postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None) - serializer = getattr(self.datamodule.data_pipeline, '_serializer', None) - deserializer = getattr(self.datamodule.data_pipeline, '_deserializer', None) - - elif self.trainer is not None and hasattr(self.trainer, 'datamodule') and getattr( - self.trainer.datamodule, 'data_pipeline', None - ) is not None: - old_data_source = getattr(self.trainer.datamodule.data_pipeline, 'data_source', None) - preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None) - postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None) - serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None) - deserializer = getattr(self.trainer.datamodule.data_pipeline, '_deserializer', None) + if self.datamodule is not None and getattr(self.datamodule, "data_pipeline", None) is not None: + old_data_source = getattr(self.datamodule.data_pipeline, "data_source", None) + preprocess = getattr(self.datamodule.data_pipeline, "_preprocess_pipeline", None) + postprocess = getattr(self.datamodule.data_pipeline, "_postprocess_pipeline", None) + serializer = getattr(self.datamodule.data_pipeline, "_serializer", None) + deserializer = getattr(self.datamodule.data_pipeline, "_deserializer", None) + + elif ( + self.trainer is not None + and hasattr(self.trainer, "datamodule") + and getattr(self.trainer.datamodule, "data_pipeline", None) is not None + ): + old_data_source = getattr(self.trainer.datamodule.data_pipeline, "data_source", None) + preprocess = getattr(self.trainer.datamodule.data_pipeline, "_preprocess_pipeline", None) + postprocess = getattr(self.trainer.datamodule.data_pipeline, "_postprocess_pipeline", None) + serializer = getattr(self.trainer.datamodule.data_pipeline, "_serializer", None) + deserializer = getattr(self.trainer.datamodule.data_pipeline, "_deserializer", None) else: # TODO: we should log with low severity level that we use defaults to create # `preprocess`, `postprocess` and `serializer`. @@ -397,10 +587,10 @@ def build_data_pipeline( preprocess, postprocess, serializer, - getattr(data_pipeline, '_deserializer', None), - getattr(data_pipeline, '_preprocess_pipeline', None), - getattr(data_pipeline, '_postprocess_pipeline', None), - getattr(data_pipeline, '_serializer', None), + getattr(data_pipeline, "_deserializer", None), + getattr(data_pipeline, "_preprocess_pipeline", None), + getattr(data_pipeline, "_postprocess_pipeline", None), + getattr(data_pipeline, "_serializer", None), ) data_source = data_source or old_data_source @@ -415,6 +605,8 @@ def build_data_pipeline( deserializer = getattr(preprocess, "deserializer", deserializer) data_pipeline = DataPipeline(data_source, preprocess, postprocess, deserializer, serializer) + self._data_pipeline_state = self._data_pipeline_state or DataPipelineState() + self.attach_data_pipeline_state(self._data_pipeline_state) self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) return data_pipeline @@ -426,8 +618,11 @@ def is_servable(self) -> bool: @torch.jit.unused @property def data_pipeline(self) -> DataPipeline: - """The current :class:`.DataPipeline`. If set, the new value will override the :class:`.Task` defaults. See - :py:meth:`~build_data_pipeline` for more details on the resolution order.""" + """The current :class:`.DataPipeline`. + + If set, the new value will override the :class:`.Task` defaults. See + :py:meth:`~build_data_pipeline` for more details on the resolution order. + """ return self.build_data_pipeline() @torch.jit.unused @@ -438,11 +633,12 @@ def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None: self._preprocess, self._postprocess, self._serializer, - getattr(data_pipeline, '_deserializer', None), - getattr(data_pipeline, '_preprocess_pipeline', None), - getattr(data_pipeline, '_postprocess_pipeline', None), - getattr(data_pipeline, '_serializer', None), + getattr(data_pipeline, "_deserializer", None), + getattr(data_pipeline, "_preprocess_pipeline", None), + getattr(data_pipeline, "_postprocess_pipeline", None), + getattr(data_pipeline, "_serializer", None), ) + # self._preprocess.state_dict() if getattr(self._preprocess, "_ddp_params_and_buffers_to_ignore", None): self._ddp_params_and_buffers_to_ignore = self._preprocess._ddp_params_and_buffers_to_ignore @@ -450,12 +646,12 @@ def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None: @torch.jit.unused @property def preprocess(self) -> Preprocess: - return getattr(self.data_pipeline, '_preprocess_pipeline', None) + return getattr(self.data_pipeline, "_preprocess_pipeline", None) @torch.jit.unused @property def postprocess(self) -> Postprocess: - return getattr(self.data_pipeline, '_postprocess_pipeline', None) + return getattr(self.data_pipeline, "_postprocess_pipeline", None) def on_train_dataloader(self) -> None: if self.data_pipeline is not None: @@ -494,25 +690,45 @@ def on_fit_end(self) -> None: def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # This may be an issue since here we create the same problems with pickle as in # https://pytorch.org/docs/stable/notes/serialization.html - if self.data_pipeline is not None and 'data_pipeline' not in checkpoint: - checkpoint['data_pipeline'] = self.data_pipeline - if self._data_pipeline_state is not None and '_data_pipeline_state' not in checkpoint: - checkpoint['_data_pipeline_state'] = self._data_pipeline_state + if self.data_pipeline is not None and "data_pipeline" not in checkpoint: + try: + pickle.dumps(self.data_pipeline) # TODO: DataPipeline not always pickleable + checkpoint["data_pipeline"] = self.data_pipeline + except AttributeError: + rank_zero_warn("DataPipeline couldn't be added to the checkpoint.") + if self._data_pipeline_state is not None and "_data_pipeline_state" not in checkpoint: + checkpoint["_data_pipeline_state"] = self._data_pipeline_state super().on_save_checkpoint(checkpoint) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: super().on_load_checkpoint(checkpoint) - if 'data_pipeline' in checkpoint: - self.data_pipeline = checkpoint['data_pipeline'] - if '_data_pipeline_state' in checkpoint: - self._data_pipeline_state = checkpoint['_data_pipeline_state'] + if "data_pipeline" in checkpoint: + self.data_pipeline = checkpoint["data_pipeline"] + if "_data_pipeline_state" in checkpoint: + self._data_pipeline_state = checkpoint["_data_pipeline_state"] @classmethod - def available_backbones(cls) -> List[str]: - registry: Optional[FlashRegistry] = getattr(cls, "backbones", None) - if registry is None: - return [] - return registry.available_keys() + def available_backbones(cls, head: Optional[str] = None) -> Union[Dict[str, List[str]], List[str]]: + if head is None: + registry: Optional[FlashRegistry] = getattr(cls, "backbones", None) + if registry is not None: + return registry.available_keys() + heads = cls.available_heads() + else: + heads = [head] + + result = {} + for head in heads: + metadata = cls.heads.get(head, with_metadata=True)["metadata"] + if "backbones" in metadata: + backbones = metadata["backbones"].available_keys() + else: + backbones = cls.available_backbones() + result[head] = backbones + + if len(result) == 1: + result = next(iter(result.values())) + return result @classmethod def available_heads(cls) -> List[str]: @@ -592,14 +808,13 @@ def _instantiate_scheduler(self, optimizer: Optimizer) -> _LRScheduler: def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ): - if 'preprocess.state_dict' in state_dict: + if "preprocess.state_dict" in state_dict: try: preprocess_state_dict = state_dict["preprocess.state_dict"] meta = preprocess_state_dict["_meta"] cls = getattr(import_module(meta["module"]), meta["class_name"]) self._preprocess = cls.load_state_dict( - {k: v - for k, v in preprocess_state_dict.items() if k != '_meta'}, + {k: v for k, v in preprocess_state_dict.items() if k != "_meta"}, strict=strict, ) self._preprocess._state = meta["_state"] @@ -620,7 +835,7 @@ def configure_callbacks(self): if flash._IS_TESTING and torch.cuda.is_available(): return [BenchmarkConvergenceCI()] - @_requires_extras("serve") + @requires_extras("serve") def run_serve_sanity_check(self): if not self.is_servable: raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.") @@ -640,8 +855,8 @@ def run_serve_sanity_check(self): resp = tc.post("http://0.0.0.0:8000/predict", json=body) print(f"Sanity check response: {resp.json()}") - @_requires_extras("serve") - def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = True) -> 'Composition': + @requires_extras("serve") + def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool = True) -> "Composition": if not self.is_servable: raise NotImplementedError("This Task is not servable. Attach a Deserializer to enable serving.") diff --git a/flash/core/optimizers/__init__.py b/flash/core/optimizers/__init__.py new file mode 100644 index 0000000000..76b1ef8a3e --- /dev/null +++ b/flash/core/optimizers/__init__.py @@ -0,0 +1,3 @@ +from flash.core.optimizers.lamb import LAMB # noqa: F401 +from flash.core.optimizers.lars import LARS # noqa: F401 +from flash.core.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR # noqa: F401 diff --git a/flash/core/optimizers/lamb.py b/flash/core/optimizers/lamb.py new file mode 100644 index 0000000000..a70293baa5 --- /dev/null +++ b/flash/core/optimizers/lamb.py @@ -0,0 +1,167 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# Implemented by @ananyahjha93 +# also found at: https://github.com/gridai-labs/aavae/tree/main/src/optimizers +# References: +# - https://arxiv.org/pdf/1904.00962.pdf +# - https://github.com/pytorch/pytorch/blob/1.6/torch/optim/adam.py +import math +from typing import Tuple + +import torch +from torch import nn +from torch.optim.optimizer import Optimizer + + +class LAMB(Optimizer): + r"""Extends ADAM in pytorch to incorporate LAMB algorithm from the paper: + `Large batch optimization for deep learning: Training BERT in 76 minutes `_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float): learning rate + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + exclude_from_layer_adaptation (bool, optional): layers which do not need LAMB + layer adaptation (default: False) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond `_ + (default: False) + + Example: + >>> model = nn.Linear(10, 1) + >>> optimizer = LAMB(model.parameters(), lr=0.1) + >>> optimizer.zero_grad() + >>> # loss_fn(model(input), target).backward() + >>> optimizer.step() + + .. warning:: + Since the default weight decay for LAMB is set to 0., we do not club together + 0. weight decay and exclusion from layer adaptation like LARS. This would cause + the optimizer to exclude all layers from layer adaptation. + """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0, + exclude_from_layer_adaptation: bool = False, + amsgrad: bool = False, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + exclude_from_layer_adaptation=exclude_from_layer_adaptation, + amsgrad=amsgrad, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError("LAMB does not support sparse gradients") + amsgrad = group["amsgrad"] + exclude_from_layer_adaptation = group["exclude_from_layer_adaptation"] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + if amsgrad: + max_exp_avg_sq = state["max_exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"]) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"]) + + numerator = exp_avg / bias_correction1 + update = numerator / denom + + if group["weight_decay"] != 0: + update = update.add(p.data, alpha=group["weight_decay"]) + + trust_ratio = 1.0 + if not exclude_from_layer_adaptation: + w_norm = torch.norm(p.data) + g_norm = torch.norm(update) + + if w_norm > 0 and g_norm > 0: + trust_ratio = w_norm / g_norm + + p.add_(update, alpha=-group["lr"] * trust_ratio) + + return loss diff --git a/flash/core/optimizers/lars.py b/flash/core/optimizers/lars.py new file mode 100644 index 0000000000..f43f7893ee --- /dev/null +++ b/flash/core/optimizers/lars.py @@ -0,0 +1,156 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# Implemented by @ananyahjha93 +# also found at: https://github.com/gridai-labs/aavae/tree/main/src/optimizers +# References: +# - https://arxiv.org/pdf/1708.03888.pdf +# - https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py +import torch +from torch import nn +from torch.optim.optimizer import Optimizer, required + + +class LARS(Optimizer): + r"""Extends SGD in PyTorch with LARS scaling from the paper + `Large batch training of Convolutional Networks `_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float): learning rate + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001) + eps (float, optional): eps for division denominator (default: 1e-8) + + Example: + >>> model = nn.Linear(10, 1) + >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9) + >>> optimizer.zero_grad() + >>> # loss_fn(model(input), target).backward() + >>> optimizer.step() + + .. note:: + The application of momentum in the SGD part is modified according to + the PyTorch standards. LARS scaling fits into the equation in the + following fashion. + + .. math:: + \begin{aligned} + g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\ + v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ + p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, + \end{aligned} + + where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta` denote the + parameters, gradient, velocity, momentum, and weight decay respectively. + The :math:`lars_lr` is defined by Eq. 6 in the paper. + The Nesterov version is analogously modified. + + .. warning:: + Parameters with weight decay set to 0 will automatically be excluded from + layer-wise LR scaling. This is to ensure consistency with papers like SimCLR + and BYOL. + """ + + def __init__( + self, + params, + lr=required, + momentum: float = 0, + dampening: float = 0, + weight_decay: float = 0, + nesterov: bool = False, + trust_coefficient: float = 0.001, + eps: float = 1e-8, + ): + if lr is not required and lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + + self.eps = eps + self.trust_coefficient = trust_coefficient + + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + + for group in self.param_groups: + group.setdefault("nesterov", False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # exclude scaling for params with 0 weight decay + for group in self.param_groups: + weight_decay = group["weight_decay"] + momentum = group["momentum"] + dampening = group["dampening"] + nesterov = group["nesterov"] + + for p in group["params"]: + if p.grad is None: + continue + + d_p = p.grad + p_norm = torch.norm(p.data) + g_norm = torch.norm(p.grad.data) + + # lars scaling + weight decay part + if weight_decay != 0: + if p_norm != 0 and g_norm != 0: + lars_lr = p_norm / (g_norm + p_norm * weight_decay + self.eps) + lars_lr *= self.trust_coefficient + + d_p = d_p.add(p, alpha=weight_decay) + d_p *= lars_lr + + # sgd part + if momentum != 0: + param_state = self.state[p] + if "momentum_buffer" not in param_state: + buf = param_state["momentum_buffer"] = torch.clone(d_p).detach() + else: + buf = param_state["momentum_buffer"] + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + if nesterov: + d_p = d_p.add(buf, alpha=momentum) + else: + d_p = buf + + p.add_(d_p, alpha=-group["lr"]) + + return loss diff --git a/flash/core/optimizers/lr_scheduler.py b/flash/core/optimizers/lr_scheduler.py new file mode 100644 index 0000000000..187f6c495f --- /dev/null +++ b/flash/core/optimizers/lr_scheduler.py @@ -0,0 +1,138 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# Implemented by @ananyahjha93 +# also found at: https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/optimizers/lr_scheduler.py +import math +import warnings +from typing import List + +from torch import nn +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import _LRScheduler + + +class LinearWarmupCosineAnnealingLR(_LRScheduler): + """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr + and base_lr followed by a cosine annealing schedule between base_lr and eta_min. + + .. warning:: + It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR` + after each iteration as calling it after each epoch will keep the starting lr at + warmup_start_lr for the first epoch which is 0 in most cases. + + .. warning:: + passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING. + It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of + :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing + epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling + train and validation methods. + + Example: + >>> layer = nn.Linear(10, 1) + >>> optimizer = Adam(layer.parameters(), lr=0.02) + >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40) + >>> # + >>> # the default case + >>> for epoch in range(40): + ... # train(...) + ... # validate(...) + ... scheduler.step() + >>> # + >>> # passing epoch param case + >>> for epoch in range(40): + ... scheduler.step(epoch) + ... # train(...) + ... # validate(...) + """ + + def __init__( + self, + optimizer: Optimizer, + warmup_epochs: int, + max_epochs: int, + warmup_start_lr: float = 0.0, + eta_min: float = 0.0, + last_epoch: int = -1, + ) -> None: + """ + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup_epochs (int): Maximum number of iterations for linear warmup + max_epochs (int): Maximum number of iterations + warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + """ + self.warmup_epochs = warmup_epochs + self.max_epochs = max_epochs + self.warmup_start_lr = warmup_start_lr + self.eta_min = eta_min + + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> List[float]: + """Compute learning rate using chainable form of the scheduler.""" + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", + UserWarning, + ) + + if self.last_epoch == self.warmup_epochs: + return self.base_lrs + elif self.last_epoch == 0: + return [self.warmup_start_lr] * len(self.base_lrs) + elif self.last_epoch < self.warmup_epochs: + return [ + group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: + return [ + group["lr"] + + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + + return [ + (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) + / ( + 1 + + math.cos( + math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs) + ) + ) + * (group["lr"] - self.eta_min) + + self.eta_min + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self) -> List[float]: + """Called when epoch is passed as a param to the `step` function of the scheduler.""" + if self.last_epoch < self.warmup_epochs: + return [ + self.warmup_start_lr + + self.last_epoch * (base_lr - self.warmup_start_lr) / max(1, self.warmup_epochs - 1) + for base_lr in self.base_lrs + ] + + return [ + self.eta_min + + 0.5 + * (base_lr - self.eta_min) + * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) + for base_lr in self.base_lrs + ] diff --git a/flash/core/registry.py b/flash/core/registry.py index 5763e01ab0..d5b1b1d764 100644 --- a/flash/core/registry.py +++ b/flash/core/registry.py @@ -11,21 +11,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial -from types import FunctionType +import functools from typing import Any, Callable, Dict, List, Optional, Union from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException +from flash.core.utilities.providers import Provider + _REGISTERED_FUNCTION = Dict[str, Any] -class FlashRegistry: - """ - This class is used to register function or ``functools.partial`` class to a registry. +def print_provider_info(name, providers, func): + if not isinstance(providers, List): + providers = [providers] + providers = list(providers) + if len(providers) > 1: + providers[-2] = f"{str(providers[-2])} and {str(providers[-1])}" + providers = providers[:-1] + message = f"Using '{name}' provided by {', '.join(str(provider) for provider in providers)}." + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank_zero_info(message) + return func(*args, **kwargs) - """ + return wrapper + + +class FlashRegistry: + """This class is used to register function or :class:`functools.partial` class to a registry.""" def __init__(self, name: str, verbose: bool = False) -> None: self.name = name @@ -39,7 +54,7 @@ def __contains__(self, key) -> bool: return any(key == e["name"] for e in self.functions) def __repr__(self) -> str: - return f'{self.__class__.__name__}(name={self.name}, functions={self.functions})' + return f"{self.__class__.__name__}(name={self.name}, functions={self.functions})" def get( self, @@ -48,8 +63,7 @@ def get( strict: bool = True, **metadata, ) -> Union[Callable, _REGISTERED_FUNCTION, List[_REGISTERED_FUNCTION], List[Callable]]: - """ - This function is used to gather matches from the registry: + """This function is used to gather matches from the registry: Args: key: Name of the registered function. @@ -59,7 +73,7 @@ def get( """ matches = [e for e in self.functions if key == e["name"]] if not matches: - raise KeyError(f"Key: {key} is not in {repr(self)}") + raise KeyError(f"Key: {key} is not in {type(self).__name__}") if metadata: matches = [m for m in matches if metadata.items() <= m["metadata"].items()] @@ -77,16 +91,20 @@ def _register_function( fn: Callable, name: Optional[str] = None, override: bool = False, - metadata: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None, ): - if not isinstance(fn, FunctionType) and not isinstance(fn, partial): - raise MisconfigurationException(f"You can only register a function, found: {fn}") + if not callable(fn): + raise MisconfigurationException(f"You can only register a callable, found: {fn}") name = name or fn.__name__ if self._verbose: rank_zero_info(f"Registering: {fn.__name__} function with name: {name} and metadata: {metadata}") + if "providers" in metadata: + providers = metadata["providers"] + fn = print_provider_info(name, providers, fn) + item = {"fn": fn, "name": name, "metadata": metadata or {}} matching_index = self._find_matching_index(item) @@ -110,21 +128,23 @@ def __call__( fn: Optional[Callable[..., Any]] = None, name: Optional[str] = None, override: bool = False, - **metadata + providers: Optional[Union[Provider, List[Provider]]] = None, + **metadata, ) -> Callable: - """ - This function is used to register new functions to the registry along their metadata. + """This function is used to register new functions to the registry along their metadata. Functions can be filtered using metadata using the ``get`` function. - """ + if providers is not None: + metadata["providers"] = providers + if fn is not None: self._register_function(fn=fn, name=name, override=override, metadata=metadata) return fn # raise the error ahead of time if not (name is None or isinstance(name, str)): - raise TypeError(f'`name` must be a str, found {name}') + raise TypeError(f"`name` must be a str, found {name}") def _register(cls): self._register_function(fn=cls, name=name, override=override, metadata=metadata) diff --git a/flash/core/schedulers.py b/flash/core/schedulers.py index 4e01306b2a..bfc1bc82b8 100644 --- a/flash/core/schedulers.py +++ b/flash/core/schedulers.py @@ -7,8 +7,9 @@ if _TRANSFORMERS_AVAILABLE: from transformers import optimization + functions: List[Callable] = [ - getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != 'get_scheduler') + getattr(optimization, n) for n in dir(optimization) if ("get_" in n and n != "get_scheduler") ] for fn in functions: _SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:]) diff --git a/flash/core/serve/_compat/__init__.py b/flash/core/serve/_compat/__init__.py index 439ab3add0..50af1bf725 100644 --- a/flash/core/serve/_compat/__init__.py +++ b/flash/core/serve/_compat/__init__.py @@ -1,3 +1,3 @@ from flash.core.serve._compat.cached_property import cached_property -__all__ = ("cached_property", ) +__all__ = ("cached_property",) diff --git a/flash/core/serve/_compat/cached_property.py b/flash/core/serve/_compat/cached_property.py index a2fa77def5..2adde68103 100644 --- a/flash/core/serve/_compat/cached_property.py +++ b/flash/core/serve/_compat/cached_property.py @@ -5,7 +5,7 @@ credits: https://github.com/penguinolog/backports.cached_property """ -__all__ = ("cached_property", ) +__all__ = ("cached_property",) # Standard Library from sys import version_info @@ -26,11 +26,9 @@ class cached_property: # NOSONAR # pylint: disable=invalid-name # noqa: N801 """Cached property implementation. - Transform a method of a class into a property whose value is computed once - and then cached as a normal attribute for the life of the instance. - Similar to property(), with the addition of caching. - Useful for expensive computed properties of instances - that are otherwise effectively immutable. + Transform a method of a class into a property whose value is computed once and then cached as a normal attribute + for the life of the instance. Similar to property(), with the addition of caching. Useful for expensive computed + properties of instances that are otherwise effectively immutable. """ def __init__(self, func: Callable[[Any], _T]) -> None: diff --git a/flash/core/serve/component.py b/flash/core/serve/component.py index d74a5a15b7..e528a64750 100644 --- a/flash/core/serve/component.py +++ b/flash/core/serve/component.py @@ -7,7 +7,7 @@ from flash.core.serve.core import ParameterContainer, Servable from flash.core.serve.decorators import BoundMeta, UnboundMeta -from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _requires_extras, _SERVE_AVAILABLE +from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _SERVE_AVAILABLE, requires_extras if _CYTOOLZ_AVAILABLE: from cytoolz import first, isiterable, valfilter @@ -41,7 +41,7 @@ def _validate_exposed_input_parameters_valid(instance): ) -def _validate_subclass_init_signature(cls: Type['ModelComponent']): +def _validate_subclass_init_signature(cls: Type["ModelComponent"]): """Raises SyntaxError if the __init__ method is not formatted correctly. Expects arguments: ['self', 'models', Optional['config']] @@ -76,7 +76,7 @@ class to perform the analysis on def _validate_model_args( args: Union[_ServableType, List[_ServableType], Tuple[_ServableType, ...], Dict[str, _ServableType]] ) -> None: - """Validator for machine learning models + """Validator for machine learning models. Parameters ---------- @@ -94,19 +94,19 @@ def _validate_model_args( raise ValueError(f"Iterable args={args} must have length >= 1") if isinstance(args, (list, tuple)): - if not all((isinstance(x, _Servable_t) for x in args)): + if not all(isinstance(x, _Servable_t) for x in args): raise TypeError(f"One of arg in args={args} is not type {_Servable_t}") elif isinstance(args, dict): - if not all((isinstance(x, str) for x in args.keys())): + if not all(isinstance(x, str) for x in args.keys()): raise TypeError(f"One of keys in args={args.keys()} is not type {str}") - if not all((isinstance(x, _Servable_t) for x in args.values())): + if not all(isinstance(x, _Servable_t) for x in args.values()): raise TypeError(f"One of values in args={args} is not type {_Servable_t}") elif not isinstance(args, _Servable_t): raise TypeError(f"Args must be instance, list/tuple, or mapping of {_Servable_t}") def _validate_config_args(config: Optional[Dict[str, Union[str, int, float, bytes]]]) -> None: - """Validator for the configuration + """Validator for the configuration. Parameters ---------- @@ -143,11 +143,9 @@ def _validate_config_args(config: Optional[Dict[str, Union[str, int, float, byte class FlashServeMeta(type): - """ - We keep a mapping of externally used names to classes. - """ + """We keep a mapping of externally used names to classes.""" - @_requires_extras("serve") + @requires_extras("serve") def __new__(cls, name, bases, namespace): # create new instance of cls in order to apply any @expose class decorations. _tmp_cls = super().__new__(cls, name, bases, namespace) @@ -165,7 +163,9 @@ def __new__(cls, name, bases, namespace): # alter namespace to insert flash serve info as bound components of class. exposed = first(ex_meths.values()) namespace["_flashserve_meta_"] = exposed.flashserve_meta - namespace["__call__"] = wraps(exposed)(exposed, ) + namespace["__call__"] = wraps(exposed)( + exposed, + ) new_cls = super().__new__(cls, name, bases, namespace) if new_cls.__name__ != "ModelComponent": @@ -181,8 +181,8 @@ def __new__(cls, name, bases, namespace): def __call__(cls, *args, **kwargs): """Customize steps taken during class creation / initalization. - super().__call__() within metaclass means: return instance - created by calling metaclass __prepare__ -> __new__ -> __init__ + super().__call__() within metaclass means: return instance created by calling metaclass __prepare__ -> __new__ + -> __init__ """ klass = super().__call__(*args, **kwargs) klass._flashserve_meta_ = replace(klass._flashserve_meta_) @@ -210,7 +210,7 @@ class ModelComponent(metaclass=FlashServeMeta): _flashserve_meta_: Optional[Union[BoundMeta, UnboundMeta]] = None def __flashserve_init__(self, models, *, config=None): - """Do a bunch of setup + """Do a bunch of setup. instance's __flashserve_init__ calls subclass __init__ in turn. """ @@ -245,5 +245,6 @@ def outputs(self) -> ParameterContainer: def uid(self) -> str: return self._flashserve_meta_.uid + else: ModelComponent = object diff --git a/flash/core/serve/composition.py b/flash/core/serve/composition.py index 5a6642cb4a..f3f9e8441e 100644 --- a/flash/core/serve/composition.py +++ b/flash/core/serve/composition.py @@ -14,8 +14,9 @@ concat, first = None, None -def _parse_composition_kwargs(**kwargs: Union[ModelComponent, - Endpoint]) -> Tuple[Dict[str, ModelComponent], Dict[str, Endpoint]]: +def _parse_composition_kwargs( + **kwargs: Union[ModelComponent, Endpoint] +) -> Tuple[Dict[str, ModelComponent], Dict[str, Endpoint]]: components, endpoints = {}, {} for k, v in kwargs.items(): @@ -28,8 +29,7 @@ def _parse_composition_kwargs(**kwargs: Union[ModelComponent, if len(components) > 1 and len(endpoints) == 0: raise ValueError( - "Must explicitly define atelast one Endpoint when " - "two or more components are included in a composition." + "Must explicitly define atelast one Endpoint when " "two or more components are included in a composition." ) return (components, endpoints) diff --git a/flash/core/serve/core.py b/flash/core/serve/core.py index f88f617184..563c0d580e 100644 --- a/flash/core/serve/core.py +++ b/flash/core/serve/core.py @@ -8,7 +8,7 @@ from flash.core.serve.types.base import BaseType from flash.core.serve.utils import download_file -from flash.core.utilities.imports import _PYDANTIC_AVAILABLE, _requires_extras +from flash.core.utilities.imports import _PYDANTIC_AVAILABLE, requires_extras if _PYDANTIC_AVAILABLE: from pydantic import FilePath, HttpUrl, parse_obj_as, ValidationError @@ -20,7 +20,7 @@ @dataclass class Endpoint: - """An endpoint maps a route and request/response payload to components + """An endpoint maps a route and request/response payload to components. Parameters ---------- @@ -41,8 +41,7 @@ class Endpoint: def __post_init__(self): if not isinstance(self.route, str): raise TypeError( - f"route parameter must be type={str}, recieved " - f"route={self.route} of type={type(self.route)}" + f"route parameter must be type={str}, recieved " f"route={self.route} of type={type(self.route)}" ) if not self.route.startswith("/"): raise ValueError("route must begin with a `slash` character (ie `/`).") @@ -76,12 +75,15 @@ def __call__(self, *args, **kwargs): return self.instance(*args, **kwargs) -ServableValidArgs_T = Union[Tuple[Type[pl.LightningModule], Union[HttpUrl, FilePath]], Tuple[HttpUrl], - Tuple[FilePath], ] +ServableValidArgs_T = Union[ + Tuple[Type[pl.LightningModule], Union[HttpUrl, FilePath]], + Tuple[HttpUrl], + Tuple[FilePath], +] class Servable: - """Wrapper around a model object to enable serving at scale. + """ModuleWrapperBase around a model object to enable serving at scale. Create a ``Servable`` from either (LM, LOCATION) or (LOCATION,) @@ -100,12 +102,12 @@ class Servable: * How to handle ``__init__`` args not recorded in hparams of ``pl.LightningModule`` """ - @_requires_extras("serve") + @requires_extras("serve") def __init__( self, *args: ServableValidArgs_T, download_path: Optional[Path] = None, - script_loader_cls: Type[FlashServeScriptLoader] = FlashServeScriptLoader + script_loader_cls: Type[FlashServeScriptLoader] = FlashServeScriptLoader, ): try: loc = args[-1] # last element in args is always loc @@ -175,16 +177,13 @@ def _repr_pretty_(self, p, cycle): # pragma: no cover def __str__(self): return ( - f"{self.source_component}.outputs.{self.source_key} >> " - f"{self.target_component}.inputs.{self.target_key}" + f"{self.source_component}.outputs.{self.source_key} >> " f"{self.target_component}.inputs.{self.target_key}" ) @dataclass class Parameter: - """ - Holder class for each grid type of a component and connections from those - to the types of other components. + """Holder class for each grid type of a component and connections from those to the types of other components. Parameters ---------- @@ -208,7 +207,7 @@ def __str__(self): return f"{self.component_uid}.{self.position}.{self.name}" def __terminate_invalid_connection_request(self, other: "Parameter", dunder_meth_called: str) -> None: - """verify that components can be composed + """verify that components can be composed. Parameters ---------- @@ -255,7 +254,7 @@ def __terminate_invalid_connection_request(self, other: "Parameter", dunder_meth ) def __lshift__(self, other: "Parameter"): - """Implements composition connecting Parameter << Parameter""" + """Implements composition connecting Parameter << Parameter.""" self.__terminate_invalid_connection_request(other, "__lshift__") con = Connection( source_component=other.component_uid, @@ -266,7 +265,7 @@ def __lshift__(self, other: "Parameter"): self.connections.append(con) def __rshift__(self, other: "Parameter"): - """Implements composition connecting Parameter >> Parameter""" + """Implements composition connecting Parameter >> Parameter.""" self.__terminate_invalid_connection_request(other, "__rshift__") con = Connection( source_component=self.component_uid, @@ -278,7 +277,6 @@ def __rshift__(self, other: "Parameter"): class DictAttrAccessBase: - def __grid_fields__(self) -> Iterator[str]: for field in dataclasses.fields(self): # noqa F402 yield field.name @@ -324,16 +322,17 @@ def make_parameter_container(data: Dict[str, Parameter]) -> ParameterContainer: ParameterContainer = make_dataclass( "ParameterContainer", dataclass_fields, - bases=(DictAttrAccessBase, ), + bases=(DictAttrAccessBase,), frozen=True, unsafe_hash=True, ) return ParameterContainer(**data) -def make_param_dict(inputs: Dict[str, BaseType], outputs: Dict[str, BaseType], - component_uid: str) -> Tuple[Dict[str, Parameter], Dict[str, Parameter]]: - """Convert exposed input/outputs parameters / dtypes to parameter objects +def make_param_dict( + inputs: Dict[str, BaseType], outputs: Dict[str, BaseType], component_uid: str +) -> Tuple[Dict[str, Parameter], Dict[str, Parameter]]: + """Convert exposed input/outputs parameters / dtypes to parameter objects. Returns ------- diff --git a/flash/core/serve/dag/optimization.py b/flash/core/serve/dag/optimization.py index ee988ee1e4..ea4293798e 100644 --- a/flash/core/serve/dag/optimization.py +++ b/flash/core/serve/dag/optimization.py @@ -53,7 +53,7 @@ def cull(dsk, keys): def default_fused_linear_keys_renamer(keys): - """Create new keys for fused tasks""" + """Create new keys for fused tasks.""" typ = type(keys[0]) if typ is str: names = [key_split(x) for x in keys[:0:-1]] @@ -62,7 +62,7 @@ def default_fused_linear_keys_renamer(keys): if typ is tuple and len(keys[0]) > 0 and isinstance(keys[0][0], str): names = [key_split(x) for x in keys[:0:-1]] names.append(keys[0][0]) - return ("-".join(names), ) + keys[0][1:] + return ("-".join(names),) + keys[0][1:] return None @@ -265,7 +265,7 @@ def inline(dsk, keys=None, inline_constants=True, dependencies=None): def inline_functions(dsk, output, fast_functions=None, inline_constants=False, dependencies=None): - """Inline cheap functions into larger operations + """Inline cheap functions into larger operations. Examples -------- @@ -320,7 +320,7 @@ def unwrap_partial(func): def functions_of(task): - """Set of functions contained within nested task + """Set of functions contained within nested task. Examples -------- @@ -350,9 +350,8 @@ def functions_of(task): def default_fused_keys_renamer(keys, max_fused_key_length=120): """Create new keys for ``fuse`` tasks. - The optional parameter `max_fused_key_length` is used to limit the maximum - string length for each renamed key. If this parameter is set to `None`, - there is no limit. + The optional parameter `max_fused_key_length` is used to limit the maximum string length for each renamed key. If + this parameter is set to `None`, there is no limit. """ it = reversed(keys) first_key = next(it) @@ -382,7 +381,7 @@ def _enforce_max_key_limit(key_name): names = sorted(names) names.append(first_key[0]) concatenated_name = "-".join(names) - return (_enforce_max_key_limit(concatenated_name), ) + first_key[1:] + return (_enforce_max_key_limit(concatenated_name),) + first_key[1:] # PEP-484 compliant singleton constant @@ -553,16 +552,18 @@ def fuse( children_stack_pop() # This is a leaf node in the reduction region # key, task, fused_keys, height, width, number of nodes, fudge, set of edges - info_stack_append(( - child, - rv[child], - [child] if rename_keys else None, - 1, - 1, - 1, - 0, - deps[child] - reducible, - )) + info_stack_append( + ( + child, + rv[child], + [child] if rename_keys else None, + 1, + 1, + 1, + 0, + deps[child] - reducible, + ) + ) else: children_stack_pop() # Calculate metrics and fuse as appropriate @@ -592,7 +593,7 @@ def fuse( fudge += 1 # Sanity check; don't go too deep if new levels introduce new edge dependencies - if ((num_nodes + fudge) / height <= ave_width and (no_new_edges or height < max_depth_new_edges)): + if (num_nodes + fudge) / height <= ave_width and (no_new_edges or height < max_depth_new_edges): # Perform substitutions as we go val = subs(dsk[parent], child_key, child_task) deps_parent.remove(child_key) @@ -607,27 +608,31 @@ def fuse( if children_stack: if no_new_edges: # Linear fuse - info_stack_append(( - parent, - val, - child_keys, - height, - width, - num_nodes, - fudge, - edges, - )) + info_stack_append( + ( + parent, + val, + child_keys, + height, + width, + num_nodes, + fudge, + edges, + ) + ) else: - info_stack_append(( - parent, - val, - child_keys, - height + 1, - width, - num_nodes + 1, - fudge, - edges, - )) + info_stack_append( + ( + parent, + val, + child_keys, + height + 1, + width, + num_nodes + 1, + fudge, + edges, + ) + ) else: rv[parent] = val break @@ -640,16 +645,18 @@ def fuse( if fudge > int(ave_width - 1): fudge = int(ave_width - 1) # This task *implicitly* depends on `edges` - info_stack_append(( - parent, - rv[parent], - [parent] if rename_keys else None, - 1, - width, - 1, - fudge, - edges, - )) + info_stack_append( + ( + parent, + rv[parent], + [parent] if rename_keys else None, + 1, + width, + 1, + fudge, + edges, + ) + ) else: break else: @@ -717,16 +724,18 @@ def fuse( fused_trees[parent] = child_keys if children_stack: - info_stack_append(( - parent, - val, - child_keys, - height + 1, - width, - num_nodes + 1, - fudge, - edges, - )) + info_stack_append( + ( + parent, + val, + child_keys, + height + 1, + width, + num_nodes + 1, + fudge, + edges, + ) + ) else: rv[parent] = val break @@ -743,16 +752,18 @@ def fuse( fudge = int(ave_width - 1) # key, task, height, width, number of nodes, fudge, set of edges # This task *implicitly* depends on `edges` - info_stack_append(( - parent, - rv[parent], - [parent] if rename_keys else None, - 1, - width, - 1, - fudge, - edges, - )) + info_stack_append( + ( + parent, + rv[parent], + [parent] if rename_keys else None, + 1, + width, + 1, + fudge, + edges, + ) + ) else: break # Traverse upwards @@ -774,7 +785,7 @@ def fuse( def _inplace_fuse_subgraphs(dsk, keys, dependencies, fused_trees, rename_keys): - """Subroutine of fuse.Mutates dsk, depenencies, and fused_trees inplace""" + """Subroutine of fuse.Mutates dsk, depenencies, and fused_trees inplace.""" # locate all members of linear chains child2parent = {} unfusible = set() @@ -828,7 +839,7 @@ def _inplace_fuse_subgraphs(dsk, keys, dependencies, fused_trees, rename_keys): # Create new task inkeys = tuple(inkeys_set) - dsk[outkey] = (SubgraphCallable(subgraph, outkey, inkeys), ) + inkeys + dsk[outkey] = (SubgraphCallable(subgraph, outkey, inkeys),) + inkeys # Mutate `fused_trees` if key renaming is needed (renaming done in fuse) if rename_keys: diff --git a/flash/core/serve/dag/order.py b/flash/core/serve/dag/order.py index 02ba374348..da096decb9 100644 --- a/flash/core/serve/dag/order.py +++ b/flash/core/serve/dag/order.py @@ -84,7 +84,7 @@ def order(dsk, dependencies=None): - """Order nodes in the task graph + """Order nodes in the task graph. This produces an ordering over our tasks that we use to break ties when executing. We do this ahead of time to reduce a bit of stress on the @@ -151,10 +151,9 @@ def order(dsk, dependencies=None): initial_stack_key = init_stack.__getitem__ def dependents_key(x): - """Choose a path from our starting task to our tactical goal + """Choose a path from our starting task to our tactical goal. - This path is connected to a large goal, but focuses on completing - a small goal and being memory efficient. + This path is connected to a large goal, but focuses on completing a small goal and being memory efficient. """ return ( # Focus on being memory-efficient @@ -165,7 +164,7 @@ def dependents_key(x): ) def dependencies_key(x): - """Choose which dependency to run as part of a reverse DFS + """Choose which dependency to run as part of a reverse DFS. This is very similar to both ``initial_stack_key``. """ @@ -196,7 +195,7 @@ def dependencies_key(x): ) def finish_now_key(x): - """ Determine the order of dependents that are ready to run and be released""" + """Determine the order of dependents that are ready to run and be released.""" return (-len(dependencies[x]), StrComparable(x)) # Computing this for all keys can sometimes be relatively expensive :( @@ -322,7 +321,7 @@ def finish_now_key(x): if len(deps) == 1: # Fast path! We trim down `deps` above hoping to reach here. - (dep, ) = deps + (dep,) = deps if not inner_stack: if add_to_inner_stack: inner_stack = [dep] @@ -566,7 +565,7 @@ def graph_metrics(dependencies, dependents, total_dependencies): key = current_pop() parents = dependents[key] if len(parents) == 1: - (parent, ) = parents + (parent,) = parents ( total_dependents, min_dependencies, @@ -604,7 +603,7 @@ def graph_metrics(dependencies, dependents, total_dependencies): def ndependencies(dependencies, dependents): - """Number of total data elements on which this key depends + """Number of total data elements on which this key depends. For each key we return the number of tasks that must be run for us to run this task. @@ -650,7 +649,7 @@ def ndependencies(dependencies, dependents): class StrComparable: - """Wrap object so that it defaults to string comparison + """Wrap object so that it defaults to string comparison. When comparing two objects of different types Python fails @@ -666,7 +665,7 @@ class StrComparable: False """ - __slots__ = ("obj", ) + __slots__ = ("obj",) def __init__(self, obj): self.obj = obj diff --git a/flash/core/serve/dag/rewrite.py b/flash/core/serve/dag/rewrite.py index 43c6dd021f..f85cff947e 100644 --- a/flash/core/serve/dag/rewrite.py +++ b/flash/core/serve/dag/rewrite.py @@ -4,7 +4,7 @@ def head(task): - """Return the top level node of a task""" + """Return the top level node of a task.""" if istask(task): return task[0] @@ -14,7 +14,7 @@ def head(task): def args(task): - """Get the arguments for the current task""" + """Get the arguments for the current task.""" if istask(task): return task[1:] @@ -58,8 +58,8 @@ def __iter__(self): def copy(self): """Copy the traverser in its current state. - This allows the traversal to be pushed onto a stack, for easy - backtracking.""" + This allows the traversal to be pushed onto a stack, for easy backtracking. + """ return Traverser(self.term, deque(self._stack)) @@ -79,14 +79,15 @@ def current(self): return head(self.term) def skip(self): - """Skip over all subterms of the current level in the traversal""" + """Skip over all subterms of the current level in the traversal.""" self.term = self._stack.pop() class Token: """A token object. - Used to express certain objects in the traversal of a task or pattern.""" + Used to express certain objects in the traversal of a task or pattern. + """ def __init__(self, name): self.name = name @@ -114,12 +115,12 @@ def __new__(cls, edges=None, patterns=None): @property def edges(self): - """A dictionary, where the keys are edges, and the values are nodes""" + """A dictionary, where the keys are edges, and the values are nodes.""" return self[0] @property def patterns(self): - """A list of all patterns that currently match at this node""" + """A list of all patterns that currently match at this node.""" return self[1] @@ -188,7 +189,7 @@ def _apply(self, sub_dict): return term def __str__(self): - return "RewriteRule({0}, {1}, {2})".format(self.lhs, self.rhs, self.vars) + return f"RewriteRule({self.lhs}, {self.rhs}, {self.vars})" def __repr__(self): return str(self) @@ -231,7 +232,7 @@ class RuleSet: """ def __init__(self, *rules): - """Create a `RuleSet` for a number of rules + """Create a `RuleSet` for a number of rules. Parameters ---------- @@ -281,7 +282,8 @@ def iter_matches(self, term): ------ Tuples of `(rule, subs)`, where `rule` is the rewrite rule being matched, and `subs` is a dictionary mapping the variables in the lhs - of the rule to their matching values in the term.""" + of the rule to their matching values in the term. + """ S = Traverser(term) for m, syms in _match(S, self._net): @@ -292,7 +294,7 @@ def iter_matches(self, term): yield rule, subs def _rewrite(self, term): - """Apply the rewrite rules in RuleSet to top level of term""" + """Apply the rewrite rules in RuleSet to top level of term.""" for rule, sd in self.iter_matches(term): # We use for (...) because it's fast in all cases for getting the @@ -352,7 +354,7 @@ def _top_level(net, term): def _bottom_up(net, term): if istask(term): - term = (head(term), ) + tuple(_bottom_up(net, t) for t in args(term)) + term = (head(term),) + tuple(_bottom_up(net, t) for t in args(term)) elif isinstance(term, list): term = [_bottom_up(net, t) for t in args(term)] return net._rewrite(term) @@ -387,7 +389,7 @@ def _match(S, N): n = N.edges.get(VAR, None) if n: restore_state_flag = False - matches = matches + (S.term, ) + matches = matches + (S.term,) S.skip() N = n continue @@ -400,8 +402,8 @@ def _match(S, N): def _process_match(rule, syms): - """Process a match to determine if it is correct, and to find the correct - substitution that will convert the term into the pattern. + """Process a match to determine if it is correct, and to find the correct substitution that will convert the + term into the pattern. Parameters ---------- @@ -413,7 +415,8 @@ def _process_match(rule, syms): ------- A dictionary of {vars : subterms} describing the substitution to make the pattern equivalent with the term. Returns `None` if the match is - invalid.""" + invalid. + """ subs = {} varlist = rule._varlist diff --git a/flash/core/serve/dag/task.py b/flash/core/serve/dag/task.py index fa6ed0fd8e..94f132de66 100644 --- a/flash/core/serve/dag/task.py +++ b/flash/core/serve/dag/task.py @@ -41,12 +41,10 @@ def preorder_traversal(task): for item in task: if istask(item): - for i in preorder_traversal(item): - yield i + yield from preorder_traversal(item) elif isinstance(item, list): yield list - for i in preorder_traversal(item): - yield i + yield from preorder_traversal(item) else: yield item @@ -58,7 +56,7 @@ def lists_to_tuples(res, keys): def _execute_task(arg, cache): - """Do the actual work of collecting data and executing a function + """Do the actual work of collecting data and executing a function. Examples -------- @@ -134,7 +132,7 @@ def get(dsk: dict, out: Sequence[str], cache: dict = None, sortkeys: List[str] = def get_dependencies(dsk, key=None, task=no_default, as_list=False): - """Get the immediate tasks on which this task depends + """Get the immediate tasks on which this task depends. Examples -------- @@ -188,7 +186,7 @@ def get_dependencies(dsk, key=None, task=no_default, as_list=False): def get_deps(dsk): - """Get dependencies and dependents from task graph + """Get dependencies and dependents from task graph. Examples -------- @@ -222,8 +220,7 @@ def flatten(seq, container=list): else: for item in seq: if isinstance(item, container): - for item2 in flatten(item, container=container): - yield item2 + yield from flatten(item, container=container) else: yield item @@ -246,7 +243,7 @@ def reverse_dict(d): def subs(task, key, val): - """Perform a substitution on a task + """Perform a substitution on a task. Examples -------- @@ -289,8 +286,7 @@ def subs(task, key, val): def _toposort(dsk, keys=None, returncycle=False, dependencies=None): """Stack-based depth-first search traversal. - This is based on Tarjan's method for topological sorting - (see wikipedia for pseudocode). + This is based on Tarjan's method for topological sorting (see wikipedia for pseudocode). """ if keys is None: keys = dsk @@ -363,8 +359,7 @@ def toposort(dsk, dependencies=None): def getcycle(d, keys): - """Return a list of nodes that form a cycle if graph is not a DAG. - Returns an empty list if no cycle is found. + """Return a list of nodes that form a cycle if graph is not a DAG. Returns an empty list if no cycle is found. ``keys`` may be a single key or list of keys. Examples @@ -381,8 +376,8 @@ def getcycle(d, keys): def isdag(d, keys): - """Does graph form a directed acyclic graph when calculating keys? - ``keys`` may be a single key or list of keys. + """Does graph form a directed acyclic graph when calculating keys? ``keys`` may be a single key or list of + keys. Examples -------- @@ -399,9 +394,9 @@ def isdag(d, keys): class literal: - """A small serializable object to wrap literal values without copying""" + """A small serializable object to wrap literal values without copying.""" - __slots__ = ("data", ) + __slots__ = ("data",) def __init__(self, data): self.data = data @@ -410,16 +405,15 @@ def __repr__(self): return "literal" % type(self.data).__name__ def __reduce__(self): - return (literal, (self.data, )) + return (literal, (self.data,)) def __call__(self): return self.data def quote(x): - """Ensure that this value remains this value in a task graph - Some values in task graph take on special meaning. Sometimes we want to - ensure that our data is not interpreted but remains literal. + """Ensure that this value remains this value in a task graph Some values in task graph take on special meaning. + Sometimes we want to ensure that our data is not interpreted but remains literal. Examples -------- @@ -427,5 +421,5 @@ def quote(x): (literal,) """ if istask(x) or type(x) is list or type(x) is dict: - return (literal(x), ) + return (literal(x),) return x diff --git a/flash/core/serve/dag/visualize.py b/flash/core/serve/dag/visualize.py index 24b14ce51c..bc847d984a 100644 --- a/flash/core/serve/dag/visualize.py +++ b/flash/core/serve/dag/visualize.py @@ -37,7 +37,7 @@ def _dag_to_graphviz(dag, dependencies, request_data, response_data, *, no_optim g.node(request_name, request_name, shape="oval") with g.subgraph(name=f"cluster_{cluster}") as c: c.node(task_key, task_key, shape="rectangle") - c.edge(task_key, task_key[:-len(".serial")]) + c.edge(task_key, task_key[: -len(".serial")]) g.edge(request_name, task_key) @@ -48,13 +48,13 @@ def _dag_to_graphviz(dag, dependencies, request_data, response_data, *, no_optim def visualize( - tc: 'TaskComposition', + tc: "TaskComposition", fhandle: BytesIO = None, format: str = "png", *, no_optimization: bool = False, ): - """Visualize a graph""" + """Visualize a graph.""" dsk = tc.pre_optimization_dsk if no_optimization else tc.dsk dependencies, dependents = get_deps(dsk) g = _dag_to_graphviz( diff --git a/flash/core/serve/decorators.py b/flash/core/serve/decorators.py index ae647ef14d..5569707000 100644 --- a/flash/core/serve/decorators.py +++ b/flash/core/serve/decorators.py @@ -29,7 +29,7 @@ class UnboundMeta: @dataclass(unsafe_hash=True) class BoundMeta(UnboundMeta): - models: Union[List['Servable'], Tuple['Servable', ...], Dict[str, 'Servable']] + models: Union[List["Servable"], Tuple["Servable", ...], Dict[str, "Servable"]] uid: str = field(default_factory=lambda: uuid4().hex, init=False) out_attr_dict: ParameterContainer = field(default=None, init=False) inp_attr_dict: ParameterContainer = field(default=None, init=False) @@ -66,7 +66,7 @@ def __post_init__(self): ) @property - def connections(self) -> Sequence['Connection']: + def connections(self) -> Sequence["Connection"]: connections = [] for fld in fields(self.inp_attr_dict): connections.extend(getattr(self.inp_attr_dict, fld.name).connections) @@ -154,7 +154,6 @@ def expose(inputs: Dict[str, BaseType], outputs: Dict[str, BaseType]): _validate_expose_inputs_outputs_args(outputs) def wrapper(fn): - @wraps(fn) def wrapped(func): func.flashserve_meta = UnboundMeta(exposed=func, inputs=inputs, outputs=outputs) diff --git a/flash/core/serve/execution.py b/flash/core/serve/execution.py index e3ba5485f2..1546ff76d9 100644 --- a/flash/core/serve/execution.py +++ b/flash/core/serve/execution.py @@ -134,7 +134,7 @@ class UnprocessedTaskDask: def _process_initial( - endpoint_protocol: 'EndpointProtocol', components: Dict[str, 'ModelComponent'] + endpoint_protocol: "EndpointProtocol", components: Dict[str, "ModelComponent"] ) -> UnprocessedTaskDask: """Extract task dsk and payload / results keys and return computable form. @@ -154,22 +154,18 @@ def _process_initial( # mapping payload input keys -> serialized keys / tasks payload_dsk_key_map = { - payload_key: f"{input_key}.serial" - for payload_key, input_key in endpoint_protocol.dsk_input_key_map.items() + payload_key: f"{input_key}.serial" for payload_key, input_key in endpoint_protocol.dsk_input_key_map.items() } payload_input_tasks_dsk = { - input_dsk_key: (identity, payload_key) - for payload_key, input_dsk_key in payload_dsk_key_map.items() + input_dsk_key: (identity, payload_key) for payload_key, input_dsk_key in payload_dsk_key_map.items() } # mapping result keys -> serialize keys / tasks res_dsk_key_map = { - result_key: f"{output_key}.serial" - for result_key, output_key in endpoint_protocol.dsk_output_key_map.items() + result_key: f"{output_key}.serial" for result_key, output_key in endpoint_protocol.dsk_output_key_map.items() } result_output_tasks_dsk = { - result_key: (identity, output_dsk_key) - for result_key, output_dsk_key in res_dsk_key_map.items() + result_key: (identity, output_dsk_key) for result_key, output_dsk_key in res_dsk_key_map.items() } output_keys = list(res_dsk_key_map.keys()) @@ -198,10 +194,10 @@ def _process_initial( def build_composition( - endpoint_protocol: 'EndpointProtocol', - components: Dict[str, 'ModelComponent'], - connections: List['Connection'], -) -> 'TaskComposition': + endpoint_protocol: "EndpointProtocol", + components: Dict[str, "ModelComponent"], + connections: List["Connection"], +) -> "TaskComposition": r"""Build a composed graph. Notes on easy sources to introduce bugs. @@ -342,7 +338,7 @@ def _verify_no_cycles(dsk: Dict[str, tuple], out_keys: List[str], endpoint_name: ) -def connections_from_components_map(components: Dict[str, 'ModelComponent']) -> List[Dict[str, str]]: +def connections_from_components_map(components: Dict[str, "ModelComponent"]) -> List[Dict[str, str]]: dsk_connections = [] for con in flatten([comp._flashserve_meta_.connections for comp in components.values()]): # value of target key is mapped one-to-one from value of source @@ -350,7 +346,7 @@ def connections_from_components_map(components: Dict[str, 'ModelComponent']) -> return dsk_connections -def endpoint_protocol_content(ep_proto: 'EndpointProtocol') -> 'EndpointProtoJSON': +def endpoint_protocol_content(ep_proto: "EndpointProtocol") -> "EndpointProtoJSON": ep_proto_payload_dsk_key_map = valmap(lambda x: f"{x}.serial", ep_proto.dsk_input_key_map) ep_proto_result_key_dsk_map = valmap(lambda x: f"{x}.serial", ep_proto.dsk_output_key_map) @@ -362,7 +358,7 @@ def endpoint_protocol_content(ep_proto: 'EndpointProtocol') -> 'EndpointProtoJSO ) -def merged_dag_content(ep_proto: 'EndpointProtocol', components: Dict[str, 'ModelComponent']) -> 'MergedJSON': +def merged_dag_content(ep_proto: "EndpointProtocol", components: Dict[str, "ModelComponent"]) -> "MergedJSON": init = _process_initial(ep_proto, components) dsk_connections = connections_from_components_map(components) epjson = endpoint_protocol_content(ep_proto) @@ -376,7 +372,7 @@ def merged_dag_content(ep_proto: 'EndpointProtocol', components: Dict[str, 'Mode for request_name, task_key in init.payload_dsk_map.items(): cluster, *_ = task_key.split(".") - merged_proto[task_key[:-len(".serial")]].append(task_key) + merged_proto[task_key[: -len(".serial")]].append(task_key) merged_proto[task_key].append(request_name) merged_proto = dict(merged_proto) @@ -394,7 +390,7 @@ def merged_dag_content(ep_proto: 'EndpointProtocol', components: Dict[str, 'Mode ) -def component_dag_content(components: Dict[str, 'ModelComponent']) -> 'ComponentJSON': +def component_dag_content(components: Dict[str, "ModelComponent"]) -> "ComponentJSON": dsk_connections = connections_from_components_map(components) comp_dependencies, comp_dependents, comp_funcnames = {}, {}, {} diff --git a/flash/core/serve/flash_components.py b/flash/core/serve/flash_components.py index ea5ae85392..f52afe6382 100644 --- a/flash/core/serve/flash_components.py +++ b/flash/core/serve/flash_components.py @@ -10,7 +10,6 @@ class FlashInputs(BaseType): - def __init__( self, deserializer: Callable, @@ -25,7 +24,6 @@ def deserialize(self, data: str) -> Any: # pragma: no cover class FlashOutputs(BaseType): - def __init__( self, serializer: Callable, @@ -53,7 +51,6 @@ def build_flash_serve_model_component(model): data_pipeline = model.build_data_pipeline() class FlashServeModelComponent(ModelComponent): - def __init__(self, model): self.model = model self.model.eval() diff --git a/flash/core/serve/interfaces/http.py b/flash/core/serve/interfaces/http.py index 594dea1b7f..861ad32937 100644 --- a/flash/core/serve/interfaces/http.py +++ b/flash/core/serve/interfaces/http.py @@ -35,6 +35,7 @@ try: from typing import ForwardRef + RequestModel = ForwardRef("RequestModel") ResponseModel = ForwardRef("ResponseModel") except ImportError: @@ -47,7 +48,6 @@ def _build_endpoint( dsk_composition: TaskComposition, response_model: ResponseModel, ) -> Callable[[RequestModel], ResponseModel]: - def endpoint_fn(body: request_model): session = body.session if body.session else str(uuid.uuid4()) _res = get( @@ -67,7 +67,6 @@ def endpoint_fn(body: request_model): def _build_meta(Body: RequestModel) -> Callable[[], Dict[str, Any]]: - def meta() -> Dict[str, Any]: nonlocal Body return Body.schema() @@ -76,7 +75,6 @@ def meta() -> Dict[str, Any]: def _build_alive_check() -> Callable[[], Alive]: - def alive() -> Alive: return Alive.construct(alive=True) @@ -89,7 +87,6 @@ def _build_visualization( *, no_optimization: bool = False, ): - def endpoint_visualization(request: Request): nonlocal dsk_composition, templates, no_optimization with BytesIO() as f: @@ -104,8 +101,8 @@ def endpoint_visualization(request: Request): def _build_dag_json( - components: Dict[str, 'ModelComponent'], - ep_proto: Optional['EndpointProtocol'], + components: Dict[str, "ModelComponent"], + ep_proto: Optional["EndpointProtocol"], *, show_connected_components: bool = True, ): @@ -122,7 +119,7 @@ def dag_json(): return dag_json -def setup_http_app(composition: 'Composition', debug: bool) -> 'FastAPI': +def setup_http_app(composition: "Composition", debug: bool) -> "FastAPI": from flash import __version__ app = FastAPI( @@ -163,11 +160,13 @@ def setup_http_app(composition: 'Composition', debug: bool) -> 'FastAPI': name="components JSON DAG", summary="JSON representation of component DAG", response_model=ComponentJSON, - )(_build_dag_json( - components=composition.components, - ep_proto=None, - show_connected_components=False, - )) + )( + _build_dag_json( + components=composition.components, + ep_proto=None, + show_connected_components=False, + ) + ) for ep_name, ep_proto in composition.endpoint_protocols.items(): dsk = build_composition( @@ -221,9 +220,11 @@ def setup_http_app(composition: 'Composition', debug: bool) -> 'FastAPI': tags=[ep_name], summary="JSON representatino of DAG", response_model=MergedJSON, - )(_build_dag_json( - components=composition.components, - ep_proto=ep_proto, - show_connected_components=True, - )) + )( + _build_dag_json( + components=composition.components, + ep_proto=ep_proto, + show_connected_components=True, + ) + ) return app diff --git a/flash/core/serve/interfaces/models.py b/flash/core/serve/interfaces/models.py index 949aa06dc0..3b2503b866 100644 --- a/flash/core/serve/interfaces/models.py +++ b/flash/core/serve/interfaces/models.py @@ -12,6 +12,7 @@ try: from typing import ForwardRef + RequestModel = ForwardRef("RequestModel") ResponseModel = ForwardRef("ResponseModel") except ImportError: @@ -26,39 +27,37 @@ class Alive(BaseModel): class EndpointProtocol: - """Records the model classes used to define an endpoints request/response body - - The request / response body schemas are generated dynamically depending - on the endpoint + components passed into the class initializer. Component - inputs & outputs (as defined in `@expose` object decorations) dtype - method (`serialize` and `deserialize`) type hints are inspected in order to - constuct a specification unique to the endpoint, they are returned as - subclasses of pydantic ``BaseModel``. + """Records the model classes used to define an endpoints request/response body. + + The request / response body schemas are generated dynamically depending on the endpoint + components passed into the + class initializer. Component inputs & outputs (as defined in `@expose` object decorations) dtype method (`serialize` + and `deserialize`) type hints are inspected in order to constuct a specification unique to the endpoint, they are + returned as subclasses of pydantic ``BaseModel``. """ - def __init__(self, name: str, endpoint: 'Endpoint', components: Dict[str, 'ModelComponent']): + def __init__(self, name: str, endpoint: "Endpoint", components: Dict[str, "ModelComponent"]): self._name = name self._endpoint = endpoint self._component = components @property def name(self) -> str: - """Name assigned to the endpoint definition in the composition""" + """Name assigned to the endpoint definition in the composition.""" return self._name @property def route(self) -> str: - """Endpoint HTTP route""" + """Endpoint HTTP route.""" return self._endpoint.route @property def dsk_input_key_map(self) -> Dict[str, str]: - """Map of payload key name -> key to insert in dsk before execution""" + """Map of payload key name -> key to insert in dsk before execution.""" return self._endpoint.inputs @property def dsk_output_key_map(self): - """Map output key names -> dsk output key names""" + """Map output key names -> dsk output key names.""" return self._endpoint.outputs @property @@ -121,10 +120,7 @@ def request_model(self) -> RequestModel: RequestModel = create_model( f"{self.name.title()}_RequestModel", __module__=self.__class__.__module__, - **{ - "session": (Optional[str], None), - "payload": (payload_model, ...) - }, + **{"session": (Optional[str], None), "payload": (payload_model, ...)}, ) RequestModel.update_forward_refs() return RequestModel @@ -182,10 +178,7 @@ def response_model(self) -> ResponseModel: ResponseModel = create_model( f"{self.name.title()}_Response", __module__=self.__class__.__module__, - **{ - "session": (Optional[str], None), - "result": (results_model, ...) - }, + **{"session": (Optional[str], None), "result": (results_model, ...)}, ) ResponseModel.update_forward_refs() return ResponseModel diff --git a/flash/core/serve/server.py b/flash/core/serve/server.py index 8ea1e3902a..ced1cc5fc9 100644 --- a/flash/core/serve/server.py +++ b/flash/core/serve/server.py @@ -15,20 +15,17 @@ class ServerMixin: - """Start a server to serve a composition - - debug - If the server should be started up in debug mode. By default, False. - testing - If the server should return the ``app`` instance instead of blocking - the process (via running the ``app`` in ``uvicorn``). This is used - when taking advantage of a server ``TestClient``. By default, False + """Start a server to serve a composition. + + debug If the server should be started up in debug mode. By default, False. testing If the server should + return the ``app`` instance instead of blocking the process (via running the ``app`` in ``uvicorn``). This is + used when taking advantage of a server ``TestClient``. By default, False """ DEBUG: bool TESTING: bool - def http_app(self) -> 'FastAPI': + def http_app(self) -> "FastAPI": return setup_http_app(composition=self, debug=self.DEBUG) def serve(self, host: str = "127.0.0.1", port: int = 8000): diff --git a/flash/core/serve/types/base.py b/flash/core/serve/types/base.py index 17fe4c725b..ed2349af2a 100644 --- a/flash/core/serve/types/base.py +++ b/flash/core/serve/types/base.py @@ -46,24 +46,23 @@ def type_hints(self): @abc.abstractmethod def serialize(self, data): # pragma: no cover - """Serialize the incoming data to send it through the network""" + """Serialize the incoming data to send it through the network.""" raise NotImplementedError @abc.abstractmethod def deserialize(self, *args, **kwargs): # pragma: no cover - """Take the inputs from the network and deserilize/convert them them. Output from - this method will go to the exposed method as arguments. + """Take the inputs from the network and deserilize/convert them them. + + Output from this method will go to the exposed method as arguments. """ raise NotImplementedError def packed_deserialize(self, kwargs): """Unpacks data (assuming kwargs) and calls deserialize method of child class. - While it does not seem to be doing much, and always does one thing, the - benefit comes when building sophisticated datatypes (such as Repeated) - where the developer wants to dictate how the unpacking happens. For simple - cases like Image or Bbox etc, developer would never need to know the - existence of this. Task graph would never call deserialize directly - but always call this method. + While it does not seem to be doing much, and always does one thing, the benefit comes when building + sophisticated datatypes (such as Repeated) where the developer wants to dictate how the unpacking happens. For + simple cases like Image or Bbox etc, developer would never need to know the existence of this. Task graph would + never call deserialize directly but always call this method. """ return self.deserialize(**kwargs) diff --git a/flash/core/serve/types/image.py b/flash/core/serve/types/image.py index 31d714cdb4..82a82219ea 100644 --- a/flash/core/serve/types/image.py +++ b/flash/core/serve/types/image.py @@ -20,23 +20,25 @@ class Image(BaseType): Notes ----- - * The ``modes`` parameter can take on any one of the following values. + * The ``modes`` parameter can take on any one of the following values: .. code-block:: python - 1: 1, # (1-bit pixels, black and white, stored with one pixel per byte) - "L": 1, # (8-bit pixels, black and white) - "P": 1, # (8-bit pixels, mapped to any other mode using a color palette) - "RGB": 3, # (3x8-bit pixels, true color) - "RGBX": 4, # RGB with padding - "RGBA": 4, # (4x8-bit pixels, true color with transparency mask) - "RGBa": 3, # (3x8-bit pixels, true color with pre-multiplied alpha) - "CMYK": 4, # (4x8-bit pixels, color separation) - "YCbCr": 3, # (3x8-bit pixels, color video format) - "LAB": 3, # (3x8-bit pixels, the L*a*b color space) - "HSV": 3, # (3x8-bit pixels, Hue, Saturation, Value color space) - "I": 1, # (32-bit signed integer pixels) - "F": 1, # (32-bit floating point pixels) + { + 1: 1, # (1-bit pixels, black and white, stored with one pixel per byte) + "L": 1, # (8-bit pixels, black and white) + "P": 1, # (8-bit pixels, mapped to any other mode using a color palette) + "RGB": 3, # (3x8-bit pixels, true color) + "RGBX": 4, # RGB with padding + "RGBA": 4, # (4x8-bit pixels, true color with transparency mask) + "RGBa": 3, # (3x8-bit pixels, true color with pre-multiplied alpha) + "CMYK": 4, # (4x8-bit pixels, color separation) + "YCbCr": 3, # (3x8-bit pixels, color video format) + "LAB": 3, # (3x8-bit pixels, the L*a*b color space) + "HSV": 3, # (3x8-bit pixels, Hue, Saturation, Value color space) + "I": 1, # (32-bit signed integer pixels) + "F": 1, # (32-bit floating point pixels) + } """ height: Optional[int] = None diff --git a/flash/core/serve/types/label.py b/flash/core/serve/types/label.py index 61a634154b..e44ad3cc5e 100644 --- a/flash/core/serve/types/label.py +++ b/flash/core/serve/types/label.py @@ -9,8 +9,7 @@ @dataclass(unsafe_hash=True) class Label(BaseType): - """ - Type specifically made for labels that are mapped to a key. + """Type specifically made for labels that are mapped to a key. Parameters ---------- @@ -30,11 +29,10 @@ def __post_init__(self): if self.classes is None: if self.path is None: raise ValueError( - "Must provide either classes as a list or " - "path to a text file that contains classes" + "Must provide either classes as a list or " "path to a text file that contains classes" ) with Path(self.path).open(mode="r") as f: - self.classes = tuple([item.strip() for item in f.readlines()]) + self.classes = tuple(item.strip() for item in f.readlines()) if isinstance(self.classes, dict): self._reverse_map = {} for key, value in self.classes.items(): diff --git a/flash/core/serve/types/repeated.py b/flash/core/serve/types/repeated.py index d6def4347b..5efa86902b 100644 --- a/flash/core/serve/types/repeated.py +++ b/flash/core/serve/types/repeated.py @@ -50,7 +50,7 @@ def __post_init__(self): def deserialize(self, *args: Dict) -> Tuple[Tensor, ...]: if (self.max_len is not None) and (len(args) > self.max_len): raise ValueError(f"len(arg)={len(args)} > self.max_len={self.max_len}") - return tuple((self.dtype.deserialize(**item) for item in args)) + return tuple(self.dtype.deserialize(**item) for item in args) def packed_deserialize(self, args): """Arguments are positional arguments for deserialize, unlike other datatypes.""" @@ -59,4 +59,4 @@ def packed_deserialize(self, args): def serialize(self, args: Sequence) -> Tuple[Any, ...]: if (self.max_len is not None) and (len(args) > self.max_len): raise ValueError(f"len(arg)={len(args)} > self.max_len={self.max_len}") - return tuple((self.dtype.serialize(item) for item in args)) + return tuple(self.dtype.serialize(item) for item in args) diff --git a/flash/core/serve/types/table.py b/flash/core/serve/types/table.py index 22e3e57e9a..5b993e7c57 100644 --- a/flash/core/serve/types/table.py +++ b/flash/core/serve/types/table.py @@ -65,8 +65,7 @@ def deserialize(self, features: Dict[Union[int, str], Dict[int, Any]]): df = pd.DataFrame.from_dict(features) if len(self.column_names) != len(df.columns) or not np.all(df.columns == self.column_names): raise RuntimeError( - f"Failed to validate column names. \nExpected: " - f"{self.column_names}\nReceived: {list(df.columns)}" + f"Failed to validate column names. \nExpected: " f"{self.column_names}\nReceived: {list(df.columns)}" ) # TODO: This strict type checking needs to be changed when numpy arrays are returned if df.values.dtype.name not in allowed_types: diff --git a/flash/core/serve/types/text.py b/flash/core/serve/types/text.py index 287307e40b..9ac5f08bcc 100644 --- a/flash/core/serve/types/text.py +++ b/flash/core/serve/types/text.py @@ -9,8 +9,7 @@ @dataclass(unsafe_hash=True) class Text(BaseType): - """ - Type for converting string to tensor and back + """Type for converting string to tensor and back. Parameters ---------- diff --git a/flash/core/serve/utils.py b/flash/core/serve/utils.py index 511d44a76e..94ea9690cb 100644 --- a/flash/core/serve/utils.py +++ b/flash/core/serve/utils.py @@ -7,7 +7,7 @@ def fn_outputs_to_keyed_map(serialize_fn_out_keys, fn_output) -> Dict[str, Any]: - """ "convert outputs of a function to a dict of `{result_name: values}` + """convert outputs of a function to a dict of `{result_name: values}` accepts function outputs which are sequence, dict, or object. """ @@ -20,7 +20,7 @@ def fn_outputs_to_keyed_map(serialize_fn_out_keys, fn_output) -> Dict[str, Any]: def download_file(url: str, *, download_path: Optional[Path] = None) -> str: - """Download to cwd with filename as last part of address, return filepath + """Download to cwd with filename as last part of address, return filepath. Returns ------- @@ -49,8 +49,7 @@ def download_file(url: str, *, download_path: Optional[Path] = None) -> str: def _module_available(module_path: str) -> bool: - """ - Check if a path is available in your environment + """Check if a path is available in your environment. >>> _module_available('os') True diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 44faef2810..e376e3316b 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -32,8 +32,8 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): - """Modified version of ``pytorch_lightning.utilities.argparse.from_argparse_args`` which populates ``valid_kwargs`` - from ``pytorch_lightning.Trainer``.""" + """Modified version of :func:`pytorch_lightning.utilities.argparse.from_argparse_args` which populates + ``valid_kwargs`` from :class:`pytorch_lightning.Trainer`.""" if isinstance(args, ArgumentParser): args = cls.parse_argparser(args) @@ -48,8 +48,10 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): def _defaults_from_env_vars(fn: Callable) -> Callable: - """Copy of ``pytorch_lightning.trainer.connectors.env_vars_connector._defaults_from_env_vars``. Required to fix - build error in readthedocs.""" + """Copy of ``pytorch_lightning.trainer.connectors.env_vars_connector._defaults_from_env_vars``. + + Required to fix build error in readthedocs. + """ @wraps(fn) def insert_env_defaults(self, *args, **kwargs): @@ -70,7 +72,6 @@ def insert_env_defaults(self, *args, **kwargs): class Trainer(PlTrainer): - @_defaults_from_env_vars def __init__(self, *args, serve_sanity_check: bool = False, **kwargs): if flash._IS_TESTING: @@ -164,9 +165,7 @@ def finetune( return super().fit(model, train_dataloader, val_dataloaders, datamodule) def _resolve_callbacks(self, model, strategy): - """ - This function is used to select the `BaseFinetuning` to be used for finetuning. - """ + """This function is used to select the `BaseFinetuning` to be used for finetuning.""" if strategy is not None and not isinstance(strategy, (str, BaseFinetuning)): raise MisconfigurationException( "strategy should be a ``pytorch_lightning.callbacks.BaseFinetuning``" @@ -186,7 +185,8 @@ def _resolve_callbacks(self, model, strategy): if strategy is not None: rank_zero_warn( "The model contains a default finetune callback. The provided {strategy} will be overriden.\n" - " HINT: Provide a `BaseFinetuning` callback as strategy to make it prioritized. ", UserWarning + " HINT: Provide a `BaseFinetuning` callback as strategy to make it prioritized. ", + UserWarning, ) callback = model_callback else: @@ -196,10 +196,8 @@ def _resolve_callbacks(self, model, strategy): @staticmethod def _merge_callbacks(old_callbacks: List, new_callbacks: List) -> List: - """ - This function keeps only 1 instance of each callback type, - extending new_callbacks with old_callbacks - """ + """This function keeps only 1 instance of each callback type, extending new_callbacks with + old_callbacks.""" if len(new_callbacks) == 0: return old_callbacks new_callbacks_types = {type(c) for c in new_callbacks} @@ -210,12 +208,15 @@ def _merge_callbacks(old_callbacks: List, new_callbacks: List) -> List: @classmethod def add_argparse_args(cls, *args, **kwargs) -> ArgumentParser: + """See :func:`pytorch_lightning.utilities.argparse.add_argparse_args`.""" # the lightning trainer implementation does not support subclasses. # context: https://github.com/PyTorchLightning/lightning-flash/issues/342#issuecomment-848892447 return add_argparse_args(PlTrainer, *args, **kwargs) @classmethod - def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer': + def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> "Trainer": + """Modified version of :func:`pytorch_lightning.utilities.argparse.from_argparse_args` which populates + ``valid_kwargs`` from :class:`pytorch_lightning.Trainer`.""" # the lightning trainer implementation does not support subclasses. # context: https://github.com/PyTorchLightning/lightning-flash/issues/342#issuecomment-848892447 return from_argparse_args(Trainer, args, **kwargs) diff --git a/flash/core/utilities/apply_func.py b/flash/core/utilities/apply_func.py index af35c39e44..27e2d34960 100644 --- a/flash/core/utilities/apply_func.py +++ b/flash/core/utilities/apply_func.py @@ -28,10 +28,8 @@ def get_callable_dict(fn: Union[Callable, Mapping, Sequence]) -> Union[Dict, Map def _is_overriden(method_name: str, instance: object, parent: Type[object]) -> bool: - """ - Cropped Version of - https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/model_helpers.py - """ + """Cropped Version of https://github.com/PyTorchLightning/pytorch- + lightning/blob/master/pytorch_lightning/utilities/model_helpers.py.""" if not hasattr(instance, method_name): return False diff --git a/flash/core/utilities/flash_cli.py b/flash/core/utilities/flash_cli.py new file mode 100644 index 0000000000..7cfb341342 --- /dev/null +++ b/flash/core/utilities/flash_cli.py @@ -0,0 +1,200 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import functools +import inspect +from functools import wraps +from inspect import Parameter, signature +from typing import Any, Callable, List, Optional, Set, Type + +import pytorch_lightning as pl +from jsonargparse import ArgumentParser +from jsonargparse.signatures import get_class_signature_functions + +import flash +from flash.core.data.data_source import DefaultDataSources +from flash.core.utilities.lightning_cli import class_from_function, LightningCLI + + +def drop_kwargs(func): + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + # Override signature + sig = signature(func) + sig = sig.replace( + parameters=tuple(p for p in sig.parameters.values() if p.kind is not p.VAR_KEYWORD and p.name != "self") + ) + if inspect.isclass(func): + sig = sig.replace(return_annotation=func) + wrapper.__signature__ = sig + + return wrapper + + +def make_args_optional(cls, args: Set[str]): + @wraps(cls) + def wrapper(*args, **kwargs): + return cls(*args, **kwargs) + + # Override signature + sig = signature(cls) + parameters = [p for p in sig.parameters.values() if p.name not in args or p.default != p.empty] + filtered_parameters = [p for p in sig.parameters.values() if p.name in args and p.default == p.empty] + + index = [i for i, p in enumerate(parameters) if p.kind == p.VAR_KEYWORD] + if index == []: + index = len(parameters) + else: + index = index[0] + + for p in filtered_parameters: + new_parameter = Parameter(p.name, p.POSITIONAL_OR_KEYWORD, default=None, annotation=Optional[p.annotation]) + parameters.insert(index, new_parameter) + + sig = sig.replace(parameters=parameters, return_annotation=cls) + wrapper.__signature__ = sig + + return wrapper + + +def get_overlapping_args(func_a, func_b) -> Set[str]: + func_a = get_class_signature_functions([func_a])[0][1] + func_b = get_class_signature_functions([func_b])[0][1] + return set(inspect.signature(func_a).parameters.keys() & inspect.signature(func_b).parameters.keys()) + + +class FlashCLI(LightningCLI): + def __init__( + self, + model_class: Type[pl.LightningModule], + datamodule_class: Type["flash.DataModule"], + trainer_class: Type[pl.Trainer] = flash.Trainer, + default_datamodule_builder: Optional[Callable] = None, + additional_datamodule_builders: Optional[List[Callable]] = None, + default_arguments=None, + finetune=True, + datamodule_attributes=None, + **kwargs: Any, + ) -> None: + """Flash's extension of the :class:`pytorch_lightning.utilities.cli.LightningCLI` + + Args: + model_class: The :class:`pytorch_lightning.LightningModule` class to train on. + datamodule_class: The :class:`~flash.data.data_module.DataModule` class. + trainer_class: An optional extension of the :class:`pytorch_lightning.Trainer` class. + trainer_fn: The trainer function to run. + datasource: Use this if your ``DataModule`` is created using a classmethod. Any of: + - ``None``. The ``datamodule_class.__init__`` signature will be used. + - ``str``. One of :class:`~flash.data.data_source.DefaultDataSources`. This will use the signature of + the corresponding ``DataModule.from_*`` method. + - ``Callable``. A custom method. + kwargs: See the parent arguments + """ + if datamodule_attributes is None: + datamodule_attributes = {"num_classes"} + self.datamodule_attributes = datamodule_attributes + + self.default_datamodule_builder = default_datamodule_builder + self.additional_datamodule_builders = additional_datamodule_builders or [] + self.default_arguments = default_arguments or {} + self.finetune = finetune + + model_class = make_args_optional(model_class, self.datamodule_attributes) + self.local_datamodule_class = datamodule_class + + self._subcommand_builders = {} + + super().__init__(drop_kwargs(model_class), datamodule_class=None, trainer_class=trainer_class, **kwargs) + + @contextlib.contextmanager + def patch_default_subcommand(self): + parse_common = self.parser._parse_common + + if self.default_datamodule_builder is not None: + + @functools.wraps(parse_common) + def wrapper(cfg, *args, **kwargs): + if "subcommand" not in cfg or cfg["subcommand"] is None: + cfg["subcommand"] = self.default_datamodule_builder.__name__ + return parse_common(cfg, *args, **kwargs) + + self.parser._parse_common = wrapper + + yield + + self.parser._parse_common = parse_common + + def parse_arguments(self) -> None: + with self.patch_default_subcommand(): + super().parse_arguments() + + def add_arguments_to_parser(self, parser) -> None: + subcommands = parser.add_subcommands() + + data_sources = self.local_datamodule_class.preprocess_cls().available_data_sources() + + for data_source in data_sources: + if isinstance(data_source, DefaultDataSources): + data_source = data_source.value + if hasattr(self.local_datamodule_class, f"from_{data_source}"): + self.add_subcommand_from_function( + subcommands, getattr(self.local_datamodule_class, f"from_{data_source}") + ) + + for datamodule_builder in self.additional_datamodule_builders: + self.add_subcommand_from_function(subcommands, datamodule_builder) + + if self.default_datamodule_builder is not None: + self.add_subcommand_from_function(subcommands, self.default_datamodule_builder) + + parser.set_defaults(self.default_arguments) + + def add_subcommand_from_function(self, subcommands, function, function_name=None): + subcommand = ArgumentParser() + datamodule_function = class_from_function(drop_kwargs(function)) + preprocess_function = class_from_function(drop_kwargs(self.local_datamodule_class.preprocess_cls)) + subcommand.add_class_arguments(datamodule_function, fail_untyped=False) + subcommand.add_class_arguments( + preprocess_function, fail_untyped=False, skip=get_overlapping_args(datamodule_function, preprocess_function) + ) + subcommand_name = function_name or function.__name__ + subcommands.add_subcommand(subcommand_name, subcommand) + self._subcommand_builders[subcommand_name] = function + + def instantiate_classes(self) -> None: + """Instantiates the classes using settings from self.config.""" + sub_config = self.config.get("subcommand") + self.datamodule = self._subcommand_builders[sub_config](**self.config.get(sub_config)) + + for datamodule_attribute in self.datamodule_attributes: + if datamodule_attribute in self.config["model"]: + if getattr(self.datamodule, datamodule_attribute, None) is not None: + self.config["model"][datamodule_attribute] = getattr(self.datamodule, datamodule_attribute) + self.config_init = self.parser.instantiate_classes(self.config) + self.model = self.config_init["model"] + self.instantiate_trainer() + + def prepare_fit_kwargs(self): + super().prepare_fit_kwargs() + if self.finetune: + # TODO: expose the strategy arguments? + self.fit_kwargs["strategy"] = "freeze" + + def fit(self) -> None: + if self.finetune: + self.trainer.finetune(**self.fit_kwargs) + else: + self.trainer.fit(**self.fit_kwargs) diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 8632802001..1a7be19e05 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""General utilities""" import functools import importlib import operator import types from importlib.util import find_spec +from typing import Callable, List, Union +from warnings import warn from pkg_resources import DistributionNotFound @@ -27,8 +28,7 @@ def _module_available(module_path: str) -> bool: - """ - Check if a path is available in your environment + """Check if a path is available in your environment. >>> _module_available('os') True @@ -43,11 +43,13 @@ def _module_available(module_path: str) -> bool: except ModuleNotFoundError: # Python 3.7+ return False + except ValueError: + # Sometimes __spec__ can be None and gives a ValueError + return True def _compare_version(package: str, op, version) -> bool: - """ - Compare package version with some requirements + """Compare package version with some requirements. >>> _compare_version("torch", operator.ge, "0.1") True @@ -59,7 +61,7 @@ def _compare_version(package: str, op, version) -> bool: try: pkg_version = Version(pkg.__version__) except TypeError: - # this is mock by sphinx, so it shall return True ro generate all summaries + # this is mock by sphinx, so it shall return True to generate all summaries return True return op(pkg_version, Version(version)) @@ -84,54 +86,122 @@ def _compare_version(package: str, op, version) -> bool: _CYTOOLZ_AVAILABLE = _module_available("cytoolz") _UVICORN_AVAILABLE = _module_available("uvicorn") _PIL_AVAILABLE = _module_available("PIL") +_OPEN3D_AVAILABLE = _module_available("open3d") +_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch") +_SOUNDFILE_AVAILABLE = _module_available("soundfile") +_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter") +_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse") +_TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric") +_TORCHAUDIO_AVAILABLE = _module_available("torchaudio") +_ROUGE_SCORE_AVAILABLE = _module_available("rouge_score") +_SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece") +_DATASETS_AVAILABLE = _module_available("datasets") +_ICEVISION_AVAILABLE = _module_available("icevision") +_ICEDATA_AVAILABLE = _module_available("icedata") +_TORCH_ORT_AVAILABLE = _module_available("torch_ort") + +if _PIL_AVAILABLE: + from PIL import Image +else: + + class MetaImage(type): + def __init__(cls, name, bases, dct): + super().__init__(name, bases, dct) + + cls._Image = None + + @property + def Image(cls): + warn("Mock object called due to missing PIL library. Please use \"pip install 'lightning-flash[image]'\".") + return cls._Image + + class Image(metaclass=MetaImage): + pass + if Version: _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0") -_TEXT_AVAILABLE = _TRANSFORMERS_AVAILABLE +_TEXT_AVAILABLE = all( + [ + _TRANSFORMERS_AVAILABLE, + _ROUGE_SCORE_AVAILABLE, + _SENTENCEPIECE_AVAILABLE, + _DATASETS_AVAILABLE, + ] +) _TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE _VIDEO_AVAILABLE = _PYTORCHVIDEO_AVAILABLE -_IMAGE_AVAILABLE = all([ - _TORCHVISION_AVAILABLE, - _TIMM_AVAILABLE, - _PIL_AVAILABLE, - _KORNIA_AVAILABLE, - _MATPLOTLIB_AVAILABLE, - _COCO_AVAILABLE, - _FIFTYONE_AVAILABLE, - _PYSTICHE_AVAILABLE, -]) +_IMAGE_AVAILABLE = all( + [ + _TORCHVISION_AVAILABLE, + _TIMM_AVAILABLE, + _PIL_AVAILABLE, + _KORNIA_AVAILABLE, + _PYSTICHE_AVAILABLE, + _SEGMENTATION_MODELS_AVAILABLE, + _ICEVISION_AVAILABLE, + _ICEDATA_AVAILABLE, + ] +) _SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE +_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE and _TORCHVISION_AVAILABLE +_AUDIO_AVAILABLE = all([_TORCHAUDIO_AVAILABLE, _SOUNDFILE_AVAILABLE, _TRANSFORMERS_AVAILABLE]) +_GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE _EXTRAS_AVAILABLE = { - 'image': _IMAGE_AVAILABLE, - 'tabular': _TABULAR_AVAILABLE, - 'text': _TEXT_AVAILABLE, - 'video': _VIDEO_AVAILABLE, - 'serve': _SERVE_AVAILABLE, + "image": _IMAGE_AVAILABLE, + "tabular": _TABULAR_AVAILABLE, + "text": _TEXT_AVAILABLE, + "video": _VIDEO_AVAILABLE, + "pointcloud": _POINTCLOUD_AVAILABLE, + "serve": _SERVE_AVAILABLE, + "audio": _AUDIO_AVAILABLE, + "graph": _GRAPH_AVAILABLE, } -def _requires_extras(extras: str): +def _requires( + module_paths: Union[str, List], + module_available: Callable[[str], bool], + formatter: Callable[[List[str]], str], +): + + if not isinstance(module_paths, list): + module_paths = [module_paths] def decorator(func): + if not all(module_available(module_path) for module_path in module_paths): - @functools.wraps(func) - def wrapper(*args, **kwargs): - if not _EXTRAS_AVAILABLE[extras]: + @functools.wraps(func) + def wrapper(*args, **kwargs): raise ModuleNotFoundError( - f"Required dependencies not available. Please run: pip install 'lightning-flash[{extras}]'" + f"Required dependencies not available. Please run: pip install {formatter(module_paths)}" ) - return func(*args, **kwargs) - return wrapper + return wrapper + else: + return func return decorator +def requires(module_paths: Union[str, List]): + return _requires(module_paths, _module_available, lambda module_paths: " ".join(module_paths)) + + +def requires_extras(extras: Union[str, List]): + return _requires( + extras, lambda extras: _EXTRAS_AVAILABLE[extras], lambda extras: f"'lightning-flash[{','.join(extras)}]'" + ) + + +def example_requires(extras: Union[str, List[str]]): + return requires_extras(extras)(lambda: None)() + + def lazy_import(module_name, callback=None): - """Returns a proxy module object that will lazily import the given module - the first time it is used. + """Returns a proxy module object that will lazily import the given module the first time it is used. Example usage:: @@ -155,8 +225,7 @@ def lazy_import(module_name, callback=None): class LazyModule(types.ModuleType): - """Proxy module that lazily imports the underlying module the first time it - is actually used. + """Proxy module that lazily imports the underlying module the first time it is actually used. Args: module_name: the fully-qualified module name to import diff --git a/flash/core/utilities/isinstance.py b/flash/core/utilities/isinstance.py new file mode 100644 index 0000000000..4eed928d24 --- /dev/null +++ b/flash/core/utilities/isinstance.py @@ -0,0 +1,23 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def _typed_isinstance(__object, __class_or_tuple): + return isinstance(__object, getattr(__class_or_tuple, "__origin__", __class_or_tuple)) + + +try: + from torch.jit import isinstance as _isinstance +except ImportError: + _isinstance = _typed_isinstance diff --git a/flash/core/utilities/lightning_cli.py b/flash/core/utilities/lightning_cli.py new file mode 100644 index 0000000000..1b5170b88f --- /dev/null +++ b/flash/core/utilities/lightning_cli.py @@ -0,0 +1,485 @@ +# Adapted from the Lightning CLI: +# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/cli.py +import inspect +import os +import warnings +from argparse import Namespace +from functools import wraps +from types import MethodType +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union + +from jsonargparse import ActionConfigFile, ArgumentParser, set_config_read_mode +from jsonargparse.signatures import ClassFromFunctionBase +from jsonargparse.typehints import ClassType +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.core.datamodule import LightningDataModule +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.trainer.trainer import Trainer +from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple +from torch.optim import Optimizer + +from flash.core.data.data_module import DataModule + +set_config_read_mode(fsspec_enabled=True) + + +def class_from_function(func: Callable[..., ClassType]) -> Type[ClassType]: + """Creates a dynamic class which if instantiated is equivalent to calling func. + + Args: + func: A function that returns an instance of a class. It must have a return type annotation. + """ + + @wraps(func) + def __new__(cls, *args, **kwargs): + return func(*args, **kwargs) + + return_type = inspect.signature(func).return_annotation + if isinstance(return_type, str): + if return_type == "DataModule": + return_type = DataModule + + class ClassFromFunction(return_type, ClassFromFunctionBase): # type: ignore + pass + + ClassFromFunction.__new__ = __new__ # type: ignore + ClassFromFunction.__doc__ = func.__doc__ + ClassFromFunction.__name__ = func.__name__ + + return ClassFromFunction + + +class LightningArgumentParser(ArgumentParser): + """Extension of jsonargparse's ArgumentParser for pytorch-lightning.""" + + def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None: + """Initialize argument parser that supports configuration file input. + + For full details of accepted arguments see `ArgumentParser.__init__ + `_. + """ + super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs) + self.add_argument( + "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format." + ) + self.callback_keys: List[str] = [] + self.optimizers_and_lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} + + def add_lightning_class_args( + self, + lightning_class: Union[ + Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]], + Type[Trainer], + Type[LightningModule], + Type[LightningDataModule], + Type[Callback], + ], + nested_key: str, + subclass_mode: bool = False, + ) -> List[str]: + """Adds arguments from a lightning class to a nested key of the parser. + + Args: + lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}. + nested_key: Name of the nested namespace to store arguments. + subclass_mode: Whether allow any subclass of the given class. + """ + if callable(lightning_class) and not inspect.isclass(lightning_class): + lightning_class = class_from_function(lightning_class) + + if inspect.isclass(lightning_class) and issubclass( + cast(type, lightning_class), (Trainer, LightningModule, LightningDataModule, Callback) + ): + if issubclass(cast(type, lightning_class), Callback): + self.callback_keys.append(nested_key) + if subclass_mode: + return self.add_subclass_arguments(lightning_class, nested_key, required=True) + return self.add_class_arguments( + lightning_class, + nested_key, + fail_untyped=False, + instantiate=not issubclass(cast(type, lightning_class), Trainer), + ) + raise MisconfigurationException( + f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: " + "Trainer, LightningModule, LightningDataModule, or Callback." + ) + + def add_optimizer_args( + self, + optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]], + nested_key: str = "optimizer", + link_to: str = "AUTOMATIC", + ) -> None: + """Adds arguments from an optimizer class to a nested key of the parser. + + Args: + optimizer_class: Any subclass of torch.optim.Optimizer. + nested_key: Name of the nested namespace to store arguments. + link_to: Dot notation of a parser key to set arguments or AUTOMATIC. + """ + if isinstance(optimizer_class, tuple): + assert all(issubclass(o, Optimizer) for o in optimizer_class) + else: + assert issubclass(optimizer_class, Optimizer) + kwargs = { + "instantiate": False, + "fail_untyped": False, + "skip": {"params"}, + } + if isinstance(optimizer_class, tuple): + self.add_subclass_arguments(optimizer_class, nested_key, required=True, **kwargs) + else: + self.add_class_arguments(optimizer_class, nested_key, **kwargs) + self.optimizers_and_lr_schedulers[nested_key] = (optimizer_class, link_to) + + def add_lr_scheduler_args( + self, + lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]], + nested_key: str = "lr_scheduler", + link_to: str = "AUTOMATIC", + ) -> None: + """Adds arguments from a learning rate scheduler class to a nested key of the parser. + + Args: + lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. + nested_key: Name of the nested namespace to store arguments. + link_to: Dot notation of a parser key to set arguments or AUTOMATIC. + """ + if isinstance(lr_scheduler_class, tuple): + assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class) + else: + assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) + kwargs = { + "instantiate": False, + "fail_untyped": False, + "skip": {"optimizer"}, + } + if isinstance(lr_scheduler_class, tuple): + self.add_subclass_arguments(lr_scheduler_class, nested_key, required=True, **kwargs) + else: + self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs) + self.optimizers_and_lr_schedulers[nested_key] = (lr_scheduler_class, link_to) + + +class SaveConfigCallback(Callback): + """Saves a LightningCLI config to the log_dir when training starts. + + Raises: + RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run + """ + + def __init__( + self, + parser: LightningArgumentParser, + config: Union[Namespace, Dict[str, Any]], + config_filename: str, + overwrite: bool = False, + ) -> None: + self.parser = parser + self.config = config + self.config_filename = config_filename + self.overwrite = overwrite + + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: + # save the config in `setup` because (1) we want it to save regardless of the trainer function run + # and we want to save before processes are spawned + log_dir = trainer.log_dir + assert log_dir is not None + config_path = os.path.join(log_dir, self.config_filename) + if not self.overwrite and os.path.isfile(config_path): + raise RuntimeError( + f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting" + " results of a previous run. You can delete the previous config file," + " set `LightningCLI(save_config_callback=None)` to disable config saving," + " or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file." + ) + if trainer.is_global_zero: + # save only on rank zero to avoid race conditions on DDP. + # the `log_dir` needs to be created as we rely on the logger to do it usually + # but it hasn't logged anything at this point + get_filesystem(log_dir).makedirs(log_dir, exist_ok=True) + self.parser.save(self.config, config_path, skip_none=False, overwrite=self.overwrite) + + def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]: + # `ArgumentParser` is un-pickleable. Drop it + return ( + self.__class__, + (None, self.config, self.config_filename), + {}, + ) + + +class LightningCLI: + """Implementation of a configurable command line tool for pytorch-lightning.""" + + def __init__( + self, + model_class: Union[Type[LightningModule], Callable[..., LightningModule]], + datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None, + save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback, + save_config_filename: str = "config.yaml", + save_config_overwrite: bool = False, + trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer, + trainer_defaults: Dict[str, Any] = None, + seed_everything_default: int = None, + description: str = "pytorch-lightning trainer command line tool", + env_prefix: str = "PL", + env_parse: bool = False, + parser_kwargs: Dict[str, Any] = None, + subclass_mode_model: bool = False, + subclass_mode_data: bool = False, + ) -> None: + """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which + are called / instantiated using a parsed configuration file and / or command line args and then runs + trainer.fit. Parsing of configuration from environment variables can be enabled by setting + ``env_parse=True``. A full configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual + settings are so parsed from variables named for example ``PL_TRAINER__MAX_EPOCHS``. + + Example, first implement the ``trainer.py`` tool as:: + + from mymodels import MyModel + from pytorch_lightning.utilities.cli import LightningCLI + LightningCLI(MyModel) + + Then in a shell, run the tool with the desired configuration:: + + $ python trainer.py --print_config > config.yaml + $ nano config.yaml # modify the config as desired + $ python trainer.py --cfg config.yaml + + .. warning:: ``LightningCLI`` is in beta and subject to change. + + Args: + model_class: :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on or a callable + which returns a :class:`~pytorch_lightning.core.lightning.LightningModule` instance when called. + datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a + callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when + called. + save_config_callback: A callback class to save the training config. + save_config_filename: Filename for the config file. + save_config_overwrite: Whether to overwrite an existing config file. + trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a + callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called. + trainer_defaults: Set to override Trainer defaults or add persistent callbacks. + seed_everything_default: Default value for the :func:`~pytorch_lightning.utilities.seed.seed_everything` + seed argument. + description: Description of the tool shown when running ``--help``. + env_prefix: Prefix for environment variables. + env_parse: Whether environment variable parsing is enabled. + parser_kwargs: Additional arguments to instantiate LightningArgumentParser. + subclass_mode_model: Whether model can be any `subclass + `_ + of the given class. + subclass_mode_data: Whether datamodule can be any `subclass + `_ + of the given class. + """ + self.model_class = model_class + self.datamodule_class = datamodule_class + self.save_config_callback = save_config_callback + self.save_config_filename = save_config_filename + self.save_config_overwrite = save_config_overwrite + self.trainer_class = trainer_class + self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults + self.seed_everything_default = seed_everything_default + self.subclass_mode_model = subclass_mode_model + self.subclass_mode_data = subclass_mode_data + self.parser_kwargs = {} if parser_kwargs is None else parser_kwargs + self.parser_kwargs.update({"description": description, "env_prefix": env_prefix, "default_env": env_parse}) + + self.init_parser() + self.add_core_arguments_to_parser() + self.add_arguments_to_parser(self.parser) + self.link_optimizers_and_lr_schedulers() + self.parse_arguments() + if self.config["seed_everything"] is not None: + seed_everything(self.config["seed_everything"], workers=True) + self.before_instantiate_classes() + self.instantiate_classes() + self.add_configure_optimizers_method_to_model() + self.prepare_fit_kwargs() + self.before_fit() + self.fit() + self.after_fit() + + def init_parser(self) -> None: + """Method that instantiates the argument parser.""" + self.parser = LightningArgumentParser(**self.parser_kwargs) + + def add_core_arguments_to_parser(self) -> None: + """Adds arguments from the core classes to the parser.""" + self.parser.add_argument( + "--seed_everything", + type=Optional[int], + default=self.seed_everything_default, + help="Set to an int to run seed_everything with this value before classes instantiation", + ) + self.parser.add_lightning_class_args(self.trainer_class, "trainer") + trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"} + self.parser.set_defaults(trainer_defaults) + self.parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model) + if self.datamodule_class is not None: + self.parser.add_lightning_class_args(self.datamodule_class, "data", subclass_mode=self.subclass_mode_data) + + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + """Implement to add extra arguments to parser or link arguments. + + Args: + parser: The argument parser object to which arguments can be added + """ + + def link_optimizers_and_lr_schedulers(self) -> None: + """Creates argument links for optimizers and lr_schedulers that specified a link_to.""" + for key, (class_type, link_to) in self.parser.optimizers_and_lr_schedulers.items(): + if link_to == "AUTOMATIC": + continue + if isinstance(class_type, tuple): + self.parser.link_arguments(key, link_to) + else: + add_class_path = _add_class_path_generator(class_type) + self.parser.link_arguments(key, link_to, compute_fn=add_class_path) + + def parse_arguments(self) -> None: + """Parses command line arguments and stores it in self.config.""" + self.config = self.parser.parse_args() + + def before_instantiate_classes(self) -> None: + """Implement to run some code before instantiating the classes.""" + + def instantiate_classes(self) -> None: + """Instantiates the classes using settings from self.config.""" + self.config_init = self.parser.instantiate_classes(self.config) + self.datamodule = self.config_init.get("data") + self.model = self.config_init["model"] + self.instantiate_trainer() + + def instantiate_trainer(self) -> None: + """Instantiates the trainer using self.config_init['trainer']""" + if self.config_init["trainer"].get("callbacks") is None: + self.config_init["trainer"]["callbacks"] = [] + callbacks = [self.config_init[c] for c in self.parser.callback_keys] + self.config_init["trainer"]["callbacks"].extend(callbacks) + if "callbacks" in self.trainer_defaults: + if isinstance(self.trainer_defaults["callbacks"], list): + self.config_init["trainer"]["callbacks"].extend(self.trainer_defaults["callbacks"]) + else: + self.config_init["trainer"]["callbacks"].append(self.trainer_defaults["callbacks"]) + if self.save_config_callback and not self.config_init["trainer"]["fast_dev_run"]: + config_callback = self.save_config_callback( + self.parser, self.config, self.save_config_filename, overwrite=self.save_config_overwrite + ) + self.config_init["trainer"]["callbacks"].append(config_callback) + self.trainer = self.trainer_class(**self.config_init["trainer"]) + + def add_configure_optimizers_method_to_model(self) -> None: + """Adds to the model an automatically generated configure_optimizers method. + + If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC', then a + `configure_optimizers` method is automatically implemented in the model class. + """ + + def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]: + automatic = [] + for key, (base_class, link_to) in self.parser.optimizers_and_lr_schedulers.items(): + if not isinstance(base_class, tuple): + base_class = (base_class,) + if link_to == "AUTOMATIC" and any(issubclass(c, class_type) for c in base_class): + automatic.append(key) + return automatic + + optimizers = get_automatic(Optimizer) + lr_schedulers = get_automatic(LRSchedulerTypeTuple) + + if len(optimizers) == 0: + return + + if len(optimizers) > 1 or len(lr_schedulers) > 1: + raise MisconfigurationException( + f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer " + f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. In this case the user " + "is expected to link the argument groups and implement `configure_optimizers`, see " + "https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html" + "#optimizers-and-learning-rate-schedulers" + ) + + if is_overridden("configure_optimizers", self.model): + warnings.warn( + f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by " + f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`." + ) + + optimizer_class = self.parser.optimizers_and_lr_schedulers[optimizers[0]][0] + optimizer_init = self.config_init.get(optimizers[0], {}) + if not isinstance(optimizer_class, tuple): + optimizer_init = _global_add_class_path(optimizer_class, optimizer_init) + lr_scheduler_init = None + if lr_schedulers: + lr_scheduler_class = self.parser.optimizers_and_lr_schedulers[lr_schedulers[0]][0] + lr_scheduler_init = self.config_init.get(lr_schedulers[0], {}) + if not isinstance(lr_scheduler_class, tuple): + lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init) + + def configure_optimizers( + self: LightningModule, + ) -> Union[Optimizer, Tuple[List[Optimizer], List[LRSchedulerType]]]: + optimizer = instantiate_class(self.parameters(), optimizer_init) + if not lr_scheduler_init: + return optimizer + lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) + return [optimizer], [lr_scheduler] + + self.model.configure_optimizers = MethodType(configure_optimizers, self.model) + + def prepare_fit_kwargs(self) -> None: + """Prepares fit_kwargs including datamodule using self.config_init['data'] if given.""" + self.fit_kwargs = {"model": self.model} + if self.datamodule is not None: + self.fit_kwargs["datamodule"] = self.datamodule + + def before_fit(self) -> None: + """Implement to run some code before fit is started.""" + + def fit(self) -> None: + """Runs fit of the instantiated trainer class and prepared fit keyword arguments.""" + self.trainer.fit(**self.fit_kwargs) + + def after_fit(self) -> None: + """Implement to run some code after fit has finished.""" + + +def _global_add_class_path(class_type: Type, init_args: Dict[str, Any]) -> Dict[str, Any]: + return { + "class_path": class_type.__module__ + "." + class_type.__name__, + "init_args": init_args, + } + + +def _add_class_path_generator(class_type: Type) -> Callable[[Dict[str, Any]], Dict[str, Any]]: + def add_class_path(init_args: Dict[str, Any]) -> Dict[str, Any]: + return _global_add_class_path(class_type, init_args) + + return add_class_path + + +def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: + """Instantiates a class with the given args and init. + + Args: + args: Positional arguments required for instantiation. + init: Dict of the form {"class_path":...,"init_args":...}. + + Returns: + The instantiated class object. + """ + kwargs = init.get("init_args", {}) + if not isinstance(args, tuple): + args = (args,) + class_module, class_name = init["class_path"].rsplit(".", 1) + module = __import__(class_module, fromlist=[class_name]) + args_class = getattr(module, class_name) + return args_class(*args, **kwargs) diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py new file mode 100644 index 0000000000..f25c402683 --- /dev/null +++ b/flash/core/utilities/providers.py @@ -0,0 +1,46 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass + +PROVIDERS = [] #: testing + + +@dataclass +class Provider: + + name: str + url: str + + def __post_init__(self): + PROVIDERS.append(self) + + def __str__(self): + return f"{self.name} ({self.url})" + + +_TIMM = Provider("rwightman/pytorch-image-models", "https://github.com/rwightman/pytorch-image-models") +_DINO = Provider("Facebook Research/dino", "https://github.com/facebookresearch/dino") +_ICEVISION = Provider("airctic/IceVision", "https://github.com/airctic/icevision") +_TORCHVISION = Provider("PyTorch/torchvision", "https://github.com/pytorch/vision") +_ULTRALYTICS = Provider("Ultralytics/YOLOV5", "https://github.com/ultralytics/yolov5") +_MMDET = Provider("OpenMMLab/MMDetection", "https://github.com/open-mmlab/mmdetection") +_EFFDET = Provider("rwightman/efficientdet-pytorch", "https://github.com/rwightman/efficientdet-pytorch") +_SEGMENTATION_MODELS = Provider( + "qubvel/segmentation_models.pytorch", "https://github.com/qubvel/segmentation_models.pytorch" +) +_PYSTICHE = Provider("pystiche/pystiche", "https://github.com/pystiche/pystiche") +_HUGGINGFACE = Provider("Hugging Face/transformers", "https://github.com/huggingface/transformers") +_FAIRSEQ = Provider("PyTorch/fairseq", "https://github.com/pytorch/fairseq") +_OPEN3D_ML = Provider("Intelligent Systems Lab Org/Open3D-ML", "https://github.com/isl-org/Open3D-ML") +_PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo") diff --git a/flash/core/utilities/url_error.py b/flash/core/utilities/url_error.py new file mode 100644 index 0000000000..6f0d28676a --- /dev/null +++ b/flash/core/utilities/url_error.py @@ -0,0 +1,38 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +import urllib.error + +from pytorch_lightning.utilities import rank_zero_warn + + +def catch_url_error(fn): + @functools.wraps(fn) + def wrapper(*args, pretrained=False, **kwargs): + try: + return fn(*args, pretrained=pretrained, **kwargs) + except urllib.error.URLError: + # Hack for icevision/efficientdet to work without internet access + if "efficientdet" in kwargs.get("head", ""): + kwargs["pretrained_backbone"] = False + result = fn(*args, pretrained=False, **kwargs) + rank_zero_warn( + "Failed to download pretrained weights for the selected backbone. The backbone has been created with" + " `pretrained=False` instead. If you are loading from a local checkpoint, this warning can be safely" + " ignored.", + UserWarning, + ) + return result + + return wrapper diff --git a/flash/graph/__init__.py b/flash/graph/__init__.py new file mode 100644 index 0000000000..cb30102379 --- /dev/null +++ b/flash/graph/__init__.py @@ -0,0 +1 @@ +from flash.graph.classification import GraphClassificationData, GraphClassifier # noqa: F401 diff --git a/flash/graph/classification/__init__.py b/flash/graph/classification/__init__.py new file mode 100644 index 0000000000..f7a1b39194 --- /dev/null +++ b/flash/graph/classification/__init__.py @@ -0,0 +1,2 @@ +from flash.graph.classification.data import GraphClassificationData # noqa: F401 +from flash.graph.classification.model import GraphClassifier # noqa: F401 diff --git a/flash/graph/classification/cli.py b/flash/graph/classification/cli.py new file mode 100644 index 0000000000..f79af259d8 --- /dev/null +++ b/flash/graph/classification/cli.py @@ -0,0 +1,65 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from flash.core.utilities.flash_cli import FlashCLI +from flash.graph import GraphClassificationData, GraphClassifier + +__all__ = ["graph_classification"] + + +def from_tu_dataset( + name: str = "KKI", + val_split: float = 0.1, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> GraphClassificationData: + """Downloads and loads the TU Dataset.""" + from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE + + if _TORCH_GEOMETRIC_AVAILABLE: + from torch_geometric.datasets import TUDataset + else: + raise ModuleNotFoundError("Please, pip install -e '.[graph]'") + + dataset = TUDataset(root="data", name=name) + + return GraphClassificationData.from_datasets( + train_dataset=dataset, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def graph_classification(): + """Classify graphs.""" + cli = FlashCLI( + GraphClassifier, + GraphClassificationData, + default_datamodule_builder=from_tu_dataset, + default_arguments={ + "trainer.max_epochs": 3, + }, + finetune=False, + datamodule_attributes={"num_classes", "num_features"}, + ) + + cli.trainer.save_checkpoint("graph_classification.pt") + + +if __name__ == "__main__": + graph_classification() diff --git a/flash/graph/classification/data.py b/flash/graph/classification/data.py new file mode 100644 index 0000000000..cd5e3568f8 --- /dev/null +++ b/flash/graph/classification/data.py @@ -0,0 +1,69 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, Optional + +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DefaultDataSources +from flash.core.data.process import Preprocess +from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires_extras +from flash.graph.data import GraphDatasetDataSource + +if _GRAPH_AVAILABLE: + from torch_geometric.data.batch import Batch + from torch_geometric.transforms import NormalizeFeatures + + +class GraphClassificationPreprocess(Preprocess): + @requires_extras("graph") + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + ): + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.DATASETS: GraphDatasetDataSource(), + }, + default_data_source=DefaultDataSources.DATASETS, + ) + + def get_state_dict(self) -> Dict[str, Any]: + return self.transforms + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) + + @staticmethod + def default_transforms() -> Optional[Dict[str, Callable]]: + return {"pre_tensor_transform": NormalizeFeatures(), "collate": Batch.from_data_list} + + +class GraphClassificationData(DataModule): + """Data module for graph classification tasks.""" + + preprocess_cls = GraphClassificationPreprocess + + @property + def num_features(self): + n_cls_train = getattr(self.train_dataset, "num_features", None) + n_cls_val = getattr(self.val_dataset, "num_features", None) + n_cls_test = getattr(self.test_dataset, "num_features", None) + return n_cls_train or n_cls_val or n_cls_test diff --git a/flash/graph/classification/model.py b/flash/graph/classification/model.py new file mode 100644 index 0000000000..d8878c73c3 --- /dev/null +++ b/flash/graph/classification/model.py @@ -0,0 +1,145 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, List, Mapping, Sequence, Type, Union + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn import Linear + +from flash.core.classification import ClassificationTask +from flash.core.utilities.imports import _GRAPH_AVAILABLE + +if _GRAPH_AVAILABLE: + from torch_geometric.nn import BatchNorm, GCNConv, global_mean_pool, MessagePassing +else: + MessagePassing = None + GCNConv = None + + +class GraphBlock(nn.Module): + def __init__(self, nc_input, nc_output, conv_cls, act=nn.ReLU(), **conv_kwargs): + super().__init__() + self.conv = conv_cls(nc_input, nc_output, **conv_kwargs) + self.norm = BatchNorm(nc_output) + self.act = act + + def forward(self, x, edge_index, edge_weight): + x = self.conv(x, edge_index, edge_weight=edge_weight) + x = self.norm(x) + return self.act(x) + + +class BaseGraphModel(nn.Module): + def __init__( + self, + num_features: int, + hidden_channels: List[int], + num_classes: int, + conv_cls: Type[MessagePassing], + act=nn.ReLU(), + **conv_kwargs: Any + ): + super().__init__() + + self.blocks = nn.ModuleList() + hidden_channels = [num_features] + hidden_channels + + nc_output = num_features + + for idx in range(len(hidden_channels) - 1): + nc_input = hidden_channels[idx] + nc_output = hidden_channels[idx + 1] + graph_block = GraphBlock(nc_input, nc_output, conv_cls, act, **conv_kwargs) + self.blocks.append(graph_block) + + self.lin = Linear(nc_output, num_classes) + + def forward(self, data): + x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr + # 1. Obtain node embeddings + for block in self.blocks: + x = block(x, edge_index, edge_weight) + + # 2. Readout layer + x = global_mean_pool(x, data.batch) # [batch_size, hidden_channels] + + # 3. Apply a final classifier + x = F.dropout(x, p=0.5, training=self.training) + x = self.lin(x) + return x + + +class GraphClassifier(ClassificationTask): + """The ``GraphClassifier`` is a :class:`~flash.Task` for classifying graphs. For more details, see + :ref:`graph_classification`. + + Args: + num_features: Number of columns in table (not including target column). + num_classes: Number of classes to classify. + hidden_channels: Hidden dimension sizes. + loss_fn: Loss function for training, defaults to cross entropy. + optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. + metrics: Metrics to compute for training and evaluation. + learning_rate: Learning rate to use for training, defaults to `1e-3` + model: GraphNN used, defaults to BaseGraphModel. + conv_cls: kind of convolution used in model, defaults to GCNConv + """ + + required_extras = "graph" + + def __init__( + self, + num_features: int, + num_classes: int, + hidden_channels: Union[List[int], int] = 512, + loss_fn: Callable = F.cross_entropy, + optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, + metrics: Union[Callable, Mapping, Sequence, None] = None, + learning_rate: float = 1e-3, + model: torch.nn.Module = None, + conv_cls: Type[MessagePassing] = GCNConv, + **conv_kwargs + ): + + self.save_hyperparameters() + + if isinstance(hidden_channels, int): + hidden_channels = [hidden_channels] + + if not model: + model = BaseGraphModel(num_features, hidden_channels, num_classes, conv_cls, **conv_kwargs) + + super().__init__( + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + metrics=metrics, + learning_rate=learning_rate, + ) + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch, batch.y) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch, batch.y) + return super().validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch, batch.y) + return super().test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) diff --git a/flash/graph/data.py b/flash/graph/data.py new file mode 100644 index 0000000000..a3d020bc36 --- /dev/null +++ b/flash/graph/data.py @@ -0,0 +1,39 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Mapping, Optional + +from torch.utils.data import Dataset + +from flash.core.data.data_source import DatasetDataSource +from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires_extras + +if _GRAPH_AVAILABLE: + from torch_geometric.data import Data + from torch_geometric.data import Dataset as TorchGeometricDataset + + +class GraphDatasetDataSource(DatasetDataSource): + @requires_extras("graph") + def load_data(self, data: Dataset, dataset: Any = None) -> Dataset: + data = super().load_data(data, dataset=dataset) + if not self.predicting: + if isinstance(data, TorchGeometricDataset): + dataset.num_classes = data.num_classes + dataset.num_features = data.num_features + return data + + def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]: + if isinstance(sample, Data): + return sample + return super().load_sample(sample, dataset=dataset) diff --git a/flash/image/__init__.py b/flash/image/__init__.py index c099e1c086..b3ac7f10b6 100644 --- a/flash/image/__init__.py +++ b/flash/image/__init__.py @@ -1,11 +1,13 @@ -from flash.image.backbones import IMAGE_CLASSIFIER_BACKBONES, OBJ_DETECTION_BACKBONES # noqa: F401 from flash.image.classification import ( # noqa: F401 ImageClassificationData, ImageClassificationPreprocess, ImageClassifier, ) +from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES # noqa: F401 from flash.image.detection import ObjectDetectionData, ObjectDetector # noqa: F401 from flash.image.embedding import ImageEmbedder # noqa: F401 +from flash.image.instance_segmentation import InstanceSegmentation, InstanceSegmentationData # noqa: F401 +from flash.image.keypoint_detection import KeypointDetectionData, KeypointDetector # noqa: F401 from flash.image.segmentation import ( # noqa: F401 SemanticSegmentation, SemanticSegmentationData, diff --git a/flash/image/backbones.py b/flash/image/backbones.py deleted file mode 100644 index 103d3c37ee..0000000000 --- a/flash/image/backbones.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import functools -import os -import urllib.error -import warnings -from functools import partial -from typing import Tuple - -import torch -from pytorch_lightning import LightningModule -from pytorch_lightning.utilities import rank_zero_warn -from torch import nn - -from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE - -if _TIMM_AVAILABLE: - import timm - -if _TORCHVISION_AVAILABLE: - import torchvision - from torchvision.models.detection.backbone_utils import resnet_fpn_backbone - -if _BOLTS_AVAILABLE: - if os.getenv("WARN_MISSING_PACKAGE") == "0": - with warnings.catch_warnings(record=True) as w: - from pl_bolts.models.self_supervised import SimCLR, SwAV - else: - from pl_bolts.models.self_supervised import SimCLR, SwAV - -ROOT_S3_BUCKET = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com" - -MOBILENET_MODELS = ["mobilenet_v2"] -VGG_MODELS = ["vgg11", "vgg13", "vgg16", "vgg19"] -RESNET_MODELS = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"] -DENSENET_MODELS = ["densenet121", "densenet169", "densenet161"] -TORCHVISION_MODELS = MOBILENET_MODELS + VGG_MODELS + RESNET_MODELS + DENSENET_MODELS -BOLTS_MODELS = ["simclr-imagenet", "swav-imagenet"] - -IMAGE_CLASSIFIER_BACKBONES = FlashRegistry("backbones") -OBJ_DETECTION_BACKBONES = FlashRegistry("backbones") - - -def catch_url_error(fn): - - @functools.wraps(fn) - def wrapper(*args, pretrained=False, **kwargs): - try: - return fn(*args, pretrained=pretrained, **kwargs) - except urllib.error.URLError: - result = fn(*args, pretrained=False, **kwargs) - rank_zero_warn( - "Failed to download pretrained weights for the selected backbone. The backbone has been created with" - " `pretrained=False` instead. If you are loading from a local checkpoint, this warning can be safely" - " ignored.", UserWarning - ) - return result - - return wrapper - - -@IMAGE_CLASSIFIER_BACKBONES(name="simclr-imagenet", namespace="vision", package="bolts") -def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt", **_): - simclr: LightningModule = SimCLR.load_from_checkpoint(path_or_url, strict=False) - # remove the last two layers & turn it into a Sequential model - backbone = nn.Sequential(*list(simclr.encoder.children())[:-2]) - return backbone, 2048 - - -@IMAGE_CLASSIFIER_BACKBONES(name="swav-imagenet", namespace="vision", package="bolts") -def load_swav_imagenet( - path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar", - **_, -) -> Tuple[nn.Module, int]: - swav: LightningModule = SwAV.load_from_checkpoint(path_or_url, strict=True) - # remove the last two layers & turn it into a Sequential model - backbone = nn.Sequential(*list(swav.model.children())[:-2]) - return backbone, 2048 - - -if _TORCHVISION_AVAILABLE: - - def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: - model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) - backbone = model.features - num_features = 512 if model_name in VGG_MODELS else model.classifier[-1].in_features - return backbone, num_features - - for model_name in MOBILENET_MODELS + VGG_MODELS: - _type = "mobilenet" if model_name in MOBILENET_MODELS else "vgg" - - IMAGE_CLASSIFIER_BACKBONES( - fn=catch_url_error(partial(_fn_mobilenet_vgg, model_name)), - name=model_name, - namespace="vision", - package="torchvision", - type=_type - ) - - def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: - model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) - backbone = nn.Sequential(*list(model.children())[:-2]) - num_features = model.fc.in_features - return backbone, num_features - - def _fn_resnet_fpn( - model_name: str, - pretrained: bool = True, - trainable_layers: bool = True, - **kwargs, - ) -> Tuple[nn.Module, int]: - backbone = resnet_fpn_backbone(model_name, pretrained=pretrained, trainable_layers=trainable_layers, **kwargs) - return backbone, 256 - - for model_name in RESNET_MODELS: - IMAGE_CLASSIFIER_BACKBONES( - fn=catch_url_error(partial(_fn_resnet, model_name)), - name=model_name, - namespace="vision", - package="torchvision", - type="resnet" - ) - - OBJ_DETECTION_BACKBONES( - fn=catch_url_error(partial(_fn_resnet_fpn, model_name)), - name=model_name, - package="torchvision", - type="resnet-fpn" - ) - - def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: - model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) - backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True)) - num_features = model.classifier.in_features - return backbone, num_features - - for model_name in DENSENET_MODELS: - IMAGE_CLASSIFIER_BACKBONES( - fn=catch_url_error(partial(_fn_densenet, model_name)), - name=model_name, - namespace="vision", - package="torchvision", - type="densenet" - ) - -if _TIMM_AVAILABLE: - - def _fn_timm( - model_name: str, - pretrained: bool = True, - num_classes: int = 0, - **kwargs, - ) -> Tuple[nn.Module, int]: - backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes, **kwargs) - num_features = backbone.num_features - return backbone, num_features - - for model_name in timm.list_models(): - - if model_name in TORCHVISION_MODELS: - continue - - IMAGE_CLASSIFIER_BACKBONES( - fn=catch_url_error(partial(_fn_timm, model_name)), name=model_name, namespace="vision", package="timm" - ) - - -# Paper: Emerging Properties in Self-Supervised Vision Transformers -# https://arxiv.org/abs/2104.14294 from Mathilde Caron and al. (29 Apr 2021) -# weights from https://github.com/facebookresearch/dino -def dino_deits16(*_, **__): - backbone = torch.hub.load('facebookresearch/dino:main', 'dino_deits16') - return backbone, 384 - - -def dino_deits8(*_, **__): - backbone = torch.hub.load('facebookresearch/dino:main', 'dino_deits8') - return backbone, 384 - - -def dino_vitb16(*_, **__): - backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') - return backbone, 768 - - -def dino_vitb8(*_, **__): - backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8') - return backbone, 768 - - -IMAGE_CLASSIFIER_BACKBONES(dino_deits16) -IMAGE_CLASSIFIER_BACKBONES(dino_deits8) -IMAGE_CLASSIFIER_BACKBONES(dino_vitb16) -IMAGE_CLASSIFIER_BACKBONES(dino_vitb8) diff --git a/flash/image/classification/backbones/__init__.py b/flash/image/classification/backbones/__init__.py new file mode 100644 index 0000000000..db068b42b5 --- /dev/null +++ b/flash/image/classification/backbones/__init__.py @@ -0,0 +1,20 @@ +from flash.core.registry import FlashRegistry # noqa: F401 +from flash.image.classification.backbones.resnet import register_resnet_backbones # noqa: F401 +from flash.image.classification.backbones.timm import register_timm_backbones # noqa: F401 +from flash.image.classification.backbones.torchvision import ( # noqa: F401 + register_densenet_backbones, + register_mobilenet_vgg_backbones, + register_resnext_model, +) +from flash.image.classification.backbones.transformers import register_dino_backbones # noqa: F401 + +IMAGE_CLASSIFIER_BACKBONES = FlashRegistry("backbones") + +register_resnet_backbones(IMAGE_CLASSIFIER_BACKBONES) +register_dino_backbones(IMAGE_CLASSIFIER_BACKBONES) + +register_mobilenet_vgg_backbones(IMAGE_CLASSIFIER_BACKBONES) +register_resnext_model(IMAGE_CLASSIFIER_BACKBONES) +register_densenet_backbones(IMAGE_CLASSIFIER_BACKBONES) + +register_timm_backbones(IMAGE_CLASSIFIER_BACKBONES) diff --git a/flash/image/classification/backbones/resnet.py b/flash/image/classification/backbones/resnet.py new file mode 100644 index 0000000000..0f136e9df5 --- /dev/null +++ b/flash/image/classification/backbones/resnet.py @@ -0,0 +1,433 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# ResNet encoder adapted from: https://github.com/facebookresearch/swav/blob/master/src/resnet50.py +# as the official torchvision implementation does not support wide resnet architecture +# found in self-supervised learning model weights +from functools import partial +from typing import Any, Callable, List, Optional, Type, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.hub import load_state_dict_from_url + +from flash.core.registry import FlashRegistry +from flash.core.utilities.url_error import catch_url_error + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding.""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution.""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + __constants__ = ["downsample"] + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + __constants__ = ["downsample"] + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + zero_init_residual: bool = False, + groups: int = 1, + widen: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + first_conv3x3: bool = False, + remove_first_maxpool: bool = False, + in_chans: int = 3, + ) -> None: + + super().__init__() + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = width_per_group * widen + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + + num_out_filters = width_per_group * widen + + if first_conv3x3: + self.conv1 = nn.Conv2d(in_chans, num_out_filters, kernel_size=3, stride=1, padding=1, bias=False) + else: + self.conv1 = nn.Conv2d(in_chans, num_out_filters, kernel_size=7, stride=2, padding=3, bias=False) + + self.bn1 = norm_layer(num_out_filters) + self.relu = nn.ReLU(inplace=True) + + if remove_first_maxpool: + self.maxpool = nn.MaxPool2d(kernel_size=1, stride=1) + else: + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, num_out_filters, layers[0]) + num_out_filters *= 2 + self.layer2 = self._make_layer( + block, num_out_filters, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + num_out_filters *= 2 + self.layer3 = self._make_layer( + block, num_out_filters, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + num_out_filters *= 2 + self.layer4 = self._make_layer( + block, num_out_filters, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def forward(self, x: Tensor) -> Tensor: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + return x + + +def _resnet( + model_name: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_features: int, + pretrained: Union[bool, str] = True, + weights_paths: dict = {"supervised": None}, + **kwargs: Any, +) -> ResNet: + + pretrained_flag = (pretrained and isinstance(pretrained, bool)) or (pretrained == "supervised") + + backbone = ResNet(block, layers, **kwargs) + device = next(backbone.parameters()).get_device() + + model_weights = None + if pretrained_flag: + if "supervised" not in weights_paths: + raise KeyError(f"Supervised pretrained weights not available for {model_name}") + + model_weights = load_state_dict_from_url( + weights_paths["supervised"], map_location=torch.device("cpu") if device == -1 else torch.device(device) + ) + + # for supervised pretrained weights + model_weights.pop("fc.weight") + model_weights.pop("fc.bias") + + if not pretrained_flag and isinstance(pretrained, str): + if pretrained in weights_paths: + model_weights = load_state_dict_from_url( + weights_paths[pretrained], map_location=torch.device("cpu") if device == -1 else torch.device(device) + ) + + if "classy_state_dict" in model_weights.keys(): + model_weights = model_weights["classy_state_dict"]["base_model"]["model"]["trunk"] + model_weights = { + key.replace("_feature_blocks.", "") if "_feature_blocks." in key else key: val + for (key, val) in model_weights.items() + } + else: + raise KeyError("Unrecognized state dict. Logic for loading the current state dict missing.") + else: + raise KeyError( + f"Requested weights for {model_name} not available," f" choose from one of {weights_paths.keys()}" + ) + + if model_weights is not None: + backbone.load_state_dict(model_weights) + + return backbone, num_features + + +HTTPS_VISSL = "https://dl.fbaipublicfiles.com/vissl/model_zoo/" +RESNET50_WEIGHTS_PATHS = { + "supervised": "https://download.pytorch.org/models/resnet50-0676ba61.pth", + "simclr": HTTPS_VISSL + "simclr_rn50_800ep_simclr_8node_resnet_16_07_20.7e8feed1/" + "model_final_checkpoint_phase799.torch", + "swav": HTTPS_VISSL + "swav_in1k_rn50_800ep_swav_8node_resnet_27_07_20.a0a6b676/" + "model_final_checkpoint_phase799.torch", +} +RESNET50W2_WEIGHTS_PATHS = { + "simclr": HTTPS_VISSL + "simclr_rn50w2_1000ep_simclr_8node_resnet_16_07_20.e1e3bbf0/" + "model_final_checkpoint_phase999.torch", + "swav": HTTPS_VISSL + "swav_rn50w2_in1k_bs32_16node_400ep_swav_8node_resnet_30_07_20.93563e51/" + "model_final_checkpoint_phase399.torch", +} +RESNET50W4_WEIGHTS_PATHS = { + "simclr": HTTPS_VISSL + "simclr_rn50w4_1000ep_bs32_16node_simclr_8node_resnet_28_07_20.9e20b0ae/" + "model_final_checkpoint_phase999.torch", + "swav": HTTPS_VISSL + "swav_rn50w4_in1k_bs40_8node_400ep_swav_8node_resnet_30_07_20.1736135b/" + "model_final_checkpoint_phase399.torch", +} + +RESNET_MODELS = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet50w2", "resnet50w4"] +RESNET_PARAMS = [ + { + "block": BasicBlock, + "layers": [2, 2, 2, 2], + "num_features": 512, + "weights_paths": {"supervised": "https://download.pytorch.org/models/resnet18-f37072fd.pth"}, + }, + { + "block": BasicBlock, + "layers": [3, 4, 6, 3], + "num_features": 512, + "weights_paths": {"supervised": "https://download.pytorch.org/models/resnet34-b627a593.pth"}, + }, + {"block": Bottleneck, "layers": [3, 4, 6, 3], "num_features": 2048, "weights_paths": RESNET50_WEIGHTS_PATHS}, + { + "block": Bottleneck, + "layers": [3, 4, 23, 3], + "num_features": 2048, + "weights_paths": {"supervised": "https://download.pytorch.org/models/resnet101-63fe2227.pth"}, + }, + { + "block": Bottleneck, + "layers": [3, 8, 36, 3], + "num_features": 2048, + "weights_paths": {"supervised": "https://download.pytorch.org/models/resnet152-394f9c45.pth"}, + }, + { + "block": Bottleneck, + "layers": [3, 4, 6, 3], + "widen": 2, + "num_features": 4096, + "weights_paths": RESNET50W2_WEIGHTS_PATHS, + }, + { + "block": Bottleneck, + "layers": [3, 4, 6, 3], + "widen": 4, + "num_features": 8192, + "weights_paths": RESNET50W4_WEIGHTS_PATHS, + }, +] + + +def register_resnet_backbones(register: FlashRegistry): + for model_name, params in zip(RESNET_MODELS, RESNET_PARAMS): + register( + fn=catch_url_error(partial(_resnet, model_name=model_name, **params)), + name=model_name, + namespace="vision", + package="multiple", + type="resnet", + weights_paths=params["weights_paths"], # update + ) diff --git a/flash/image/classification/backbones/timm.py b/flash/image/classification/backbones/timm.py new file mode 100644 index 0000000000..ffdc71c39a --- /dev/null +++ b/flash/image/classification/backbones/timm.py @@ -0,0 +1,52 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Tuple + +import torch.nn as nn + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _TIMM_AVAILABLE +from flash.core.utilities.providers import _TIMM +from flash.core.utilities.url_error import catch_url_error +from flash.image.classification.backbones.torchvision import TORCHVISION_MODELS + +if _TIMM_AVAILABLE: + import timm + + def _fn_timm( + model_name: str, + pretrained: bool = True, + num_classes: int = 0, + **kwargs, + ) -> Tuple[nn.Module, int]: + backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes, **kwargs) + num_features = backbone.num_features + return backbone, num_features + + +def register_timm_backbones(register: FlashRegistry): + if _TIMM_AVAILABLE: + for model_name in timm.list_models(): + + if model_name in TORCHVISION_MODELS: + continue + + register( + fn=catch_url_error(partial(_fn_timm, model_name)), + name=model_name, + namespace="vision", + package="timm", + providers=_TIMM, + ) diff --git a/flash/image/classification/backbones/torchvision.py b/flash/image/classification/backbones/torchvision.py new file mode 100644 index 0000000000..11c59792d3 --- /dev/null +++ b/flash/image/classification/backbones/torchvision.py @@ -0,0 +1,89 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Tuple + +import torch.nn as nn + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _TORCHVISION_AVAILABLE +from flash.core.utilities.providers import _TORCHVISION +from flash.core.utilities.url_error import catch_url_error +from flash.image.classification.backbones.resnet import RESNET_MODELS + +MOBILENET_MODELS = ["mobilenet_v2"] +VGG_MODELS = ["vgg11", "vgg13", "vgg16", "vgg19"] +RESNEXT_MODELS = ["resnext50_32x4d", "resnext101_32x8d"] +DENSENET_MODELS = ["densenet121", "densenet169", "densenet161"] +TORCHVISION_MODELS = MOBILENET_MODELS + VGG_MODELS + RESNEXT_MODELS + RESNET_MODELS + DENSENET_MODELS + +if _TORCHVISION_AVAILABLE: + import torchvision + + def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: + model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) + backbone = model.features + num_features = 512 if model_name in VGG_MODELS else model.classifier[-1].in_features + return backbone, num_features + + def _fn_resnext(model_name: str, pretrained: bool = True): + model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) + backbone = nn.Sequential(*list(model.children())[:-2]) + num_features = model.fc.in_features + + return backbone, num_features + + def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: + model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) + backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True)) + num_features = model.classifier.in_features + return backbone, num_features + + +def register_mobilenet_vgg_backbones(register: FlashRegistry): + if _TORCHVISION_AVAILABLE: + for model_name in MOBILENET_MODELS + VGG_MODELS: + _type = "mobilenet" if model_name in MOBILENET_MODELS else "vgg" + + register( + fn=catch_url_error(partial(_fn_mobilenet_vgg, model_name)), + name=model_name, + namespace="vision", + type=_type, + providers=_TORCHVISION, + ) + + +def register_resnext_model(register: FlashRegistry): + if _TORCHVISION_AVAILABLE: + for model_name in RESNEXT_MODELS: + register( + fn=catch_url_error(partial(_fn_resnext, model_name)), + name=model_name, + namespace="vision", + type="resnext", + providers=_TORCHVISION, + ) + + +def register_densenet_backbones(register: FlashRegistry): + if _TORCHVISION_AVAILABLE: + for model_name in DENSENET_MODELS: + register( + fn=catch_url_error(partial(_fn_densenet, model_name)), + name=model_name, + namespace="vision", + type="densenet", + providers=_TORCHVISION, + ) diff --git a/flash/image/classification/backbones/transformers.py b/flash/image/classification/backbones/transformers.py new file mode 100644 index 0000000000..cf1fd1637c --- /dev/null +++ b/flash/image/classification/backbones/transformers.py @@ -0,0 +1,46 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from flash.core.registry import FlashRegistry +from flash.core.utilities.providers import _DINO +from flash.core.utilities.url_error import catch_url_error + + +# Paper: Emerging Properties in Self-Supervised Vision Transformers +# https://arxiv.org/abs/2104.14294 from Mathilde Caron and al. (29 Apr 2021) +# weights from https://github.com/facebookresearch/dino +def dino_deits16(*_, **__): + backbone = torch.hub.load("facebookresearch/dino:main", "dino_deits16") + return backbone, 384 + + +def dino_deits8(*_, **__): + backbone = torch.hub.load("facebookresearch/dino:main", "dino_deits8") + return backbone, 384 + + +def dino_vitb16(*_, **__): + backbone = torch.hub.load("facebookresearch/dino:main", "dino_vitb16") + return backbone, 768 + + +def dino_vitb8(*_, **__): + backbone = torch.hub.load("facebookresearch/dino:main", "dino_vitb8") + return backbone, 768 + + +def register_dino_backbones(register: FlashRegistry): + for model in (dino_deits16, dino_deits8, dino_vitb16, dino_vitb8): + register(catch_url_error(model), providers=_DINO) diff --git a/flash/image/classification/cli.py b/flash/image/classification/cli.py new file mode 100644 index 0000000000..6804c909f8 --- /dev/null +++ b/flash/image/classification/cli.py @@ -0,0 +1,74 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from flash.core.data.utils import download_data +from flash.core.utilities.flash_cli import FlashCLI +from flash.image import ImageClassificationData, ImageClassifier + +__all__ = ["image_classification"] + + +def from_hymenoptera( + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> ImageClassificationData: + """Downloads and loads the Hymenoptera (Ants, Bees) data set.""" + download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data") + return ImageClassificationData.from_folders( + train_folder="data/hymenoptera_data/train/", + val_folder="data/hymenoptera_data/val/", + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def from_movie_posters( + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> ImageClassificationData: + """Downloads and loads the movie posters genre classification data set.""" + download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "./data") + return ImageClassificationData.from_csv( + "Id", + ["Action", "Romance", "Crime", "Thriller", "Adventure"], + train_file="data/movie_posters/train/metadata.csv", + val_file="data/movie_posters/val/metadata.csv", + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def image_classification(): + """Classify images.""" + cli = FlashCLI( + ImageClassifier, + ImageClassificationData, + default_datamodule_builder=from_hymenoptera, + additional_datamodule_builders=[from_movie_posters], + default_arguments={ + "trainer.max_epochs": 3, + }, + datamodule_attributes={"num_classes", "multi_label"}, + ) + + cli.trainer.save_checkpoint("image_classification_model.pt") + + +if __name__ == "__main__": + image_classification() diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index d579c5a1ea..f83185e3f2 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -11,43 +11,54 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import numpy as np +import pandas as pd import torch from pytorch_lightning.trainer.states import RunningStage +from torch.utils.data.sampler import Sampler from flash.core.data.base_viz import BaseVisualization # for viz from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources +from flash.core.data.data_source import ( + DefaultDataKeys, + DefaultDataSources, + LabelStudioImageDataSource, + LoaderDataFrameDataSource, +) from flash.core.data.process import Deserializer, Preprocess -from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _requires_extras +from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, Image, requires, requires_extras from flash.image.classification.transforms import default_transforms, train_default_transforms from flash.image.data import ( + image_loader, ImageDeserializer, ImageFiftyOneDataSource, ImageNumpyDataSource, ImagePathsDataSource, ImageTensorDataSource, ) -from flash.core.data.data_source import LabelStudioImageDataSource if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt else: plt = None -if _PIL_AVAILABLE: - from PIL import Image -else: - class Image: - Image = None +class ImageClassificationDataFrameDataSource(LoaderDataFrameDataSource): + @requires_extras("image") + def __init__(self): + super().__init__(image_loader) + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: + sample = super().load_sample(sample, dataset) + w, h = sample[DefaultDataKeys.INPUT].size # WxH + sample[DefaultDataKeys.METADATA]["size"] = (h, w) + return sample -class ImageClassificationPreprocess(Preprocess): +class ImageClassificationPreprocess(Preprocess): def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -73,7 +84,7 @@ def __init__( DefaultDataSources.TENSORS: ImageTensorDataSource(), "data_frame": ImageClassificationDataFrameDataSource(), DefaultDataSources.CSV: ImageClassificationDataFrameDataSource(), - DefaultDataSources.LABELSTUDIO: LabelStudioImageDataSource(**data_source_kwargs) + DefaultDataSources.LABELSTUDIO: LabelStudioImageDataSource(**data_source_kwargs), }, deserializer=deserializer or ImageDeserializer(), default_data_source=DefaultDataSources.FILES, @@ -98,6 +109,220 @@ class ImageClassificationData(DataModule): preprocess_cls = ImageClassificationPreprocess + @classmethod + def from_data_frame( + cls, + input_field: str, + target_fields: Optional[Union[str, Sequence[str]]] = None, + train_data_frame: Optional[pd.DataFrame] = None, + train_images_root: Optional[str] = None, + train_resolver: Optional[Callable[[str, str], str]] = None, + val_data_frame: Optional[pd.DataFrame] = None, + val_images_root: Optional[str] = None, + val_resolver: Optional[Callable[[str, str], str]] = None, + test_data_frame: Optional[pd.DataFrame] = None, + test_images_root: Optional[str] = None, + test_resolver: Optional[Callable[[str, str], str]] = None, + predict_data_frame: Optional[pd.DataFrame] = None, + predict_images_root: Optional[str] = None, + predict_resolver: Optional[Callable[[str, str], str]] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + sampler: Optional[Type[Sampler]] = None, + **preprocess_kwargs: Any, + ) -> "DataModule": + """Creates a :class:`~flash.image.classification.data.ImageClassificationData` object from the given pandas + ``DataFrame`` objects. + + Args: + input_field: The field (column) in the pandas ``DataFrame`` to use for the input. + target_fields: The field or fields (columns) in the pandas ``DataFrame`` to use for the target. + train_data_frame: The pandas ``DataFrame`` containing the training data. + train_images_root: The directory containing the train images. If ``None``, values in the ``input_field`` + will be assumed to be the full file paths. + train_resolver: The function to use to resolve filenames given the ``train_images_root`` and IDs from the + ``input_field`` column. + val_data_frame: The pandas ``DataFrame`` containing the validation data. + val_images_root: The directory containing the validation images. If ``None``, the directory containing the + ``val_file`` will be used. + val_resolver: The function to use to resolve filenames given the ``val_images_root`` and IDs from the + ``input_field`` column. + test_data_frame: The pandas ``DataFrame`` containing the testing data. + test_images_root: The directory containing the test images. If ``None``, the directory containing the + ``test_file`` will be used. + test_resolver: The function to use to resolve filenames given the ``test_images_root`` and IDs from the + ``input_field`` column. + predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting. + predict_images_root: The directory containing the predict images. If ``None``, the directory containing the + ``predict_file`` will be used. + predict_resolver: The function to use to resolve filenames given the ``predict_images_root`` and IDs from + the ``input_field`` column. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = ImageClassificationData.from_data_frame( + "image_id", + "target", + train_data_frame=train_data, + train_images_root="data/train_images", + ) + """ + return cls.from_data_source( + "data_frame", + (train_data_frame, input_field, target_fields, train_images_root, train_resolver), + (val_data_frame, input_field, target_fields, val_images_root, val_resolver), + (test_data_frame, input_field, target_fields, test_images_root, test_resolver), + (predict_data_frame, input_field, target_fields, predict_images_root, predict_resolver), + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + sampler=sampler, + **preprocess_kwargs, + ) + + @classmethod + def from_csv( + cls, + input_field: str, + target_fields: Optional[Union[str, Sequence[str]]] = None, + train_file: Optional[str] = None, + train_images_root: Optional[str] = None, + train_resolver: Optional[Callable[[str, str], str]] = None, + val_file: Optional[str] = None, + val_images_root: Optional[str] = None, + val_resolver: Optional[Callable[[str, str], str]] = None, + test_file: Optional[str] = None, + test_images_root: Optional[str] = None, + test_resolver: Optional[Callable[[str, str], str]] = None, + predict_file: Optional[str] = None, + predict_images_root: Optional[str] = None, + predict_resolver: Optional[Callable[[str, str], str]] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + sampler: Optional[Type[Sampler]] = None, + **preprocess_kwargs: Any, + ) -> "DataModule": + """Creates a :class:`~flash.image.classification.data.ImageClassificationData` object from the given CSV + files using the :class:`~flash.core.data.data_source.DataSource` of name + :attr:`~flash.core.data.data_source.DefaultDataSources.CSV` from the passed or constructed + :class:`~flash.core.data.process.Preprocess`. + + Args: + input_field: The field (column) in the CSV file to use for the input. + target_fields: The field or fields (columns) in the CSV file to use for the target. + train_file: The CSV file containing the training data. + train_images_root: The directory containing the train images. If ``None``, the directory containing the + ``train_file`` will be used. + train_resolver: The function to use to resolve filenames given the ``train_images_root`` and IDs from the + ``input_field`` column. + val_file: The CSV file containing the validation data. + val_images_root: The directory containing the validation images. If ``None``, the directory containing the + ``val_file`` will be used. + val_resolver: The function to use to resolve filenames given the ``val_images_root`` and IDs from the + ``input_field`` column. + test_file: The CSV file containing the testing data. + test_images_root: The directory containing the test images. If ``None``, the directory containing the + ``test_file`` will be used. + test_resolver: The function to use to resolve filenames given the ``test_images_root`` and IDs from the + ``input_field`` column. + predict_file: The CSV file containing the data to use when predicting. + predict_images_root: The directory containing the predict images. If ``None``, the directory containing the + ``predict_file`` will be used. + predict_resolver: The function to use to resolve filenames given the ``predict_images_root`` and IDs from + the ``input_field`` column. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = ImageClassificationData.from_csv( + "image_id", + "target", + train_file="train_data.csv", + train_images_root="data/train_images", + ) + """ + return cls.from_data_source( + DefaultDataSources.CSV, + (train_file, input_field, target_fields, train_images_root, train_resolver), + (val_file, input_field, target_fields, val_images_root, val_resolver), + (test_file, input_field, target_fields, test_images_root, test_resolver), + (predict_file, input_field, target_fields, predict_images_root, predict_resolver), + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + sampler=sampler, + **preprocess_kwargs, + ) + def set_block_viz_window(self, value: bool) -> None: """Setter method to switch on/off matplotlib to pop up windows.""" self.data_fetcher.block_viz_window = value @@ -108,16 +333,18 @@ def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: class MatplotlibVisualization(BaseVisualization): - """Process and show the image batch and its associated label using matplotlib. - """ + """Process and show the image batch and its associated label using matplotlib.""" + max_cols: int = 4 # maximum number of columns we accept block_viz_window: bool = True # parameter to allow user to block visualisation windows @staticmethod - @_requires_extras("image") - def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: + @requires_extras("image") + def _to_numpy(img: Union[np.ndarray, torch.Tensor, Image.Image]) -> np.ndarray: out: np.ndarray - if isinstance(img, Image.Image): + if isinstance(img, np.ndarray): + out = img + elif isinstance(img, Image.Image): out = np.array(img) elif isinstance(img, torch.Tensor): out = img.squeeze(0).permute(1, 2, 0).cpu().numpy() @@ -125,7 +352,7 @@ def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: raise TypeError(f"Unknown image type. Got: {type(img)}.") return out - @_requires_extras("image") + @requires("matplotlib") def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str): # define the image grid cols: int = min(num_samples, self.max_cols) @@ -135,7 +362,10 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) fig, axs = plt.subplots(rows, cols) fig.suptitle(title) - for i, ax in enumerate(axs.ravel()): + if not isinstance(axs, np.ndarray): + axs = [axs] + + for i, ax in enumerate(axs): # unpack images and labels if isinstance(data, list): _img, _label = data[i][DefaultDataKeys.INPUT], data[i].get(DefaultDataKeys.TARGET, "") @@ -150,7 +380,7 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) # show image and set label as subplot title ax.imshow(_img) ax.set_title(str(_label)) - ax.axis('off') + ax.axis("off") plt.show(block=self.block_viz_window) def show_load_sample(self, samples: List[Any], running_stage: RunningStage): diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index 46c1f6cbd2..ba70b6988c 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -17,13 +17,13 @@ import torch from torch import nn from torch.optim.lr_scheduler import _LRScheduler -from torchmetrics import Accuracy, F1, Metric +from torchmetrics import Metric from flash.core.classification import ClassificationTask, Labels from flash.core.data.data_source import DefaultDataKeys from flash.core.data.process import Serializer from flash.core.registry import FlashRegistry -from flash.image.backbones import IMAGE_CLASSIFIER_BACKBONES +from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES class ImageClassifier(ClassificationTask): @@ -51,11 +51,12 @@ def fn_resnet(pretrained: bool = True): Args: num_classes: Number of classes to classify. backbone: A string or (model, num_features) tuple to use to compute image features, defaults to ``"resnet18"``. - pretrained: Use a pretrained backbone, defaults to ``True``. + pretrained: A bool or string to specify the pretrained weights of the backbone, defaults to ``True`` + which loads the default supervised pretrained weights. loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. optimizer: Optimizer to use for training, defaults to :class:`torch.optim.SGD`. metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics` - package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict + package, a custom metric inheriting from `torchmetrics.Metric`, a callable function or a list/dict containing a combination of the aforementioned. In all cases, each metric needs to have the signature `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`. learning_rate: Learning rate to use for training, defaults to ``1e-3``. @@ -73,7 +74,7 @@ def __init__( backbone: Union[str, Tuple[nn.Module, int]] = "resnet18", backbone_kwargs: Optional[Dict] = None, head: Optional[Union[FunctionType, nn.Module]] = None, - pretrained: bool = True, + pretrained: Union[bool, str] = True, loss_fn: Optional[Callable] = None, optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, @@ -85,16 +86,17 @@ def __init__( serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, ): super().__init__( + num_classes=num_classes, model=None, loss_fn=loss_fn, optimizer=optimizer, optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, - metrics=metrics or F1(num_classes) if multi_label else Accuracy(), + metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, - serializer=serializer or Labels(), + serializer=serializer or Labels(multi_label=multi_label), ) self.save_hyperparameters() @@ -108,7 +110,9 @@ def __init__( self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs) head = head(num_features, num_classes) if isinstance(head, FunctionType) else head - self.head = head or nn.Sequential(nn.Linear(num_features, num_classes), ) + self.head = head or nn.Sequential( + nn.Linear(num_features, num_classes), + ) def training_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) @@ -123,9 +127,9 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch[DefaultDataKeys.PREDS] = super().predict_step((batch[DefaultDataKeys.INPUT]), - batch_idx, - dataloader_idx=dataloader_idx) + batch[DefaultDataKeys.PREDS] = super().predict_step( + (batch[DefaultDataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx + ) return batch def forward(self, x) -> torch.Tensor: @@ -134,11 +138,19 @@ def forward(self, x) -> torch.Tensor: x = x.mean(-1).mean(-1) return self.head(x) + @classmethod + def available_pretrained_weights(cls, backbone: str): + result = cls.backbones.get(backbone, with_metadata=True) + pretrained_weights = None + + if "weights_paths" in result["metadata"]: + pretrained_weights = list(result["metadata"]["weights_paths"].keys()) + + return pretrained_weights + def _ci_benchmark_fn(self, history: List[Dict[str, Any]]): - """ - This function is used only for debugging usage with CI - """ + """This function is used only for debugging usage with CI.""" if self.hparams.multi_label: - assert history[-1]["val_f1"] > 0.45 + assert history[-1]["val_f1"] > 0.40, history[-1]["val_f1"] else: - assert history[-1]["val_accuracy"] > 0.90 + assert history[-1]["val_accuracy"] > 0.85, history[-1]["val_accuracy"] diff --git a/flash/image/classification/transforms.py b/flash/image/classification/transforms.py index 945f1cabc5..3b5ba98a4c 100644 --- a/flash/image/classification/transforms.py +++ b/flash/image/classification/transforms.py @@ -47,7 +47,7 @@ def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: "per_batch_transform_on_device": ApplyToKeys( DefaultDataKeys.INPUT, K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), - ) + ), } return { "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(image_size)), diff --git a/flash/image/data.py b/flash/image/data.py index 015ee19caf..45d7f2af6c 100644 --- a/flash/image/data.py +++ b/flash/image/data.py @@ -16,37 +16,47 @@ from pathlib import Path from typing import Any, Dict, Optional +import numpy as np import torch import flash from flash.core.data.data_source import ( DefaultDataKeys, FiftyOneDataSource, + has_file_allowed_extension, NumpyDataSource, PathsDataSource, TensorDataSource, ) from flash.core.data.process import Deserializer -from flash.core.utilities.imports import _PIL_AVAILABLE, _requires_extras, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image, requires_extras if _TORCHVISION_AVAILABLE: import torchvision from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS from torchvision.transforms.functional import to_pil_image else: - IMG_EXTENSIONS = [] + IMG_EXTENSIONS = () -if _PIL_AVAILABLE: - from PIL import Image as PILImage -else: - class Image: - Image = None +NP_EXTENSIONS = (".npy", ".npz") -class ImageDeserializer(Deserializer): +def image_loader(filepath: str): + if has_file_allowed_extension(filepath, IMG_EXTENSIONS): + img = default_loader(filepath) + elif has_file_allowed_extension(filepath, NP_EXTENSIONS): + img = Image.fromarray(np.load(filepath).astype("uint8"), "RGB") + else: + raise ValueError( + f"File: {filepath} has an unsupported extension. Supported extensions: " + f"{list(IMG_EXTENSIONS + NP_EXTENSIONS)}." + ) + return img - @_requires_extras("image") + +class ImageDeserializer(Deserializer): + @requires_extras("image") def __init__(self): super().__init__() self.to_tensor = torchvision.transforms.ToTensor() @@ -55,7 +65,7 @@ def deserialize(self, data: str) -> Dict: encoded_with_padding = (data + "===").encode("ascii") img = base64.b64decode(encoded_with_padding) buffer = BytesIO(img) - img = PILImage.open(buffer, mode="r") + img = Image.open(buffer, mode="r") return { DefaultDataKeys.INPUT: img, } @@ -67,25 +77,18 @@ def example_input(self) -> str: class ImagePathsDataSource(PathsDataSource): - - @_requires_extras("image") + @requires_extras("image") def __init__(self): - super().__init__(extensions=IMG_EXTENSIONS) + super().__init__(loader=image_loader, extensions=IMG_EXTENSIONS + NP_EXTENSIONS) def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - img_path = sample[DefaultDataKeys.INPUT] - img = default_loader(img_path) - sample[DefaultDataKeys.INPUT] = img - w, h = img.size # WxH - sample[DefaultDataKeys.METADATA] = { - "filepath": img_path, - "size": (h, w), - } + sample = super().load_sample(sample, dataset) + w, h = sample[DefaultDataKeys.INPUT].size # WxH + sample[DefaultDataKeys.METADATA]["size"] = (h, w) return sample class ImageTensorDataSource(TensorDataSource): - def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: img = to_pil_image(sample[DefaultDataKeys.INPUT]) sample[DefaultDataKeys.INPUT] = img @@ -95,7 +98,6 @@ def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> class ImageNumpyDataSource(NumpyDataSource): - def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: img = to_pil_image(torch.from_numpy(sample[DefaultDataKeys.INPUT])) sample[DefaultDataKeys.INPUT] = img @@ -105,7 +107,6 @@ def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> class ImageFiftyOneDataSource(FiftyOneDataSource): - @staticmethod def load_sample(sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: img_path = sample[DefaultDataKeys.INPUT] diff --git a/flash/image/detection/backbones.py b/flash/image/detection/backbones.py new file mode 100644 index 0000000000..c3e9d5cfad --- /dev/null +++ b/flash/image/detection/backbones.py @@ -0,0 +1,122 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Optional + +import torch + +from flash.core.adapter import Adapter +from flash.core.integrations.icevision.adapter import IceVisionAdapter, SimpleCOCOMetric +from flash.core.integrations.icevision.backbones import ( + get_backbones, + icevision_model_adapter, + load_icevision_ignore_image_size, + load_icevision_with_image_size, +) +from flash.core.model import Task +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE +from flash.core.utilities.providers import _EFFDET, _ICEVISION, _MMDET, _TORCHVISION, _ULTRALYTICS + +if _ICEVISION_AVAILABLE: + from icevision import models as icevision_models + from icevision.metrics import COCOMetricType + from icevision.metrics import Metric as IceVisionMetric + +OBJECT_DETECTION_HEADS = FlashRegistry("heads") + + +class IceVisionObjectDetectionAdapter(IceVisionAdapter): + @classmethod + def from_task( + cls, + task: Task, + num_classes: int, + backbone: str = "resnet18_fpn", + head: str = "retinanet", + pretrained: bool = True, + metrics: Optional["IceVisionMetric"] = None, + image_size: Optional = None, + **kwargs, + ) -> Adapter: + return super().from_task( + task, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + metrics=metrics or [SimpleCOCOMetric(COCOMetricType.bbox)], + image_size=image_size, + **kwargs, + ) + + +if _ICEVISION_AVAILABLE: + if _TORCHVISION_AVAILABLE: + for model_type in [icevision_models.torchvision.retinanet, icevision_models.torchvision.faster_rcnn]: + OBJECT_DETECTION_HEADS( + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + model_type.__name__.split(".")[-1], + backbones=get_backbones(model_type), + adapter=IceVisionObjectDetectionAdapter, + providers=[_ICEVISION, _TORCHVISION], + ) + + if _module_available("yolov5"): + model_type = icevision_models.ultralytics.yolov5 + OBJECT_DETECTION_HEADS( + partial(load_icevision_with_image_size, icevision_model_adapter, model_type), + model_type.__name__.split(".")[-1], + backbones=get_backbones(model_type), + adapter=IceVisionObjectDetectionAdapter, + providers=[_ICEVISION, _ULTRALYTICS], + ) + + if _module_available("mmdet"): + for model_type in [ + icevision_models.mmdet.faster_rcnn, + icevision_models.mmdet.retinanet, + icevision_models.mmdet.fcos, + icevision_models.mmdet.sparse_rcnn, + ]: + OBJECT_DETECTION_HEADS( + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + f"mmdet_{model_type.__name__.split('.')[-1]}", + backbones=get_backbones(model_type), + adapter=IceVisionObjectDetectionAdapter, + providers=[_ICEVISION, _MMDET], + ) + + if _module_available("effdet"): + + def _icevision_effdet_model_adapter(model_type): + class IceVisionEffdetModelAdapter(icevision_model_adapter(model_type)): + def validation_step(self, batch, batch_idx): + images = batch[0][0] + batch[0][1]["img_scale"] = torch.ones_like(images[:, 0, 0, 0]).unsqueeze(1) + batch[0][1]["img_size"] = ( + (torch.ones_like(images[:, 0, 0, 0]) * images[0].shape[-1]).unsqueeze(1).repeat(1, 2) + ) + return super().validation_step(batch, batch_idx) + + return IceVisionEffdetModelAdapter + + model_type = icevision_models.ross.efficientdet + OBJECT_DETECTION_HEADS( + partial(load_icevision_with_image_size, _icevision_effdet_model_adapter, model_type), + model_type.__name__.split(".")[-1], + backbones=get_backbones(model_type), + adapter=IceVisionObjectDetectionAdapter, + providers=[_ICEVISION, _EFFDET], + ) diff --git a/flash/image/detection/cli.py b/flash/image/detection/cli.py new file mode 100644 index 0000000000..8c2eb0c3d1 --- /dev/null +++ b/flash/image/detection/cli.py @@ -0,0 +1,56 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from flash.core.data.utils import download_data +from flash.core.utilities.flash_cli import FlashCLI +from flash.image import ObjectDetectionData, ObjectDetector + +__all__ = ["object_detection"] + + +def from_coco_128( + val_split: float = 0.1, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> ObjectDetectionData: + """Downloads and loads the COCO 128 data set.""" + download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") + return ObjectDetectionData.from_coco( + train_folder="data/coco128/images/train2017/", + train_ann_file="data/coco128/annotations/instances_train2017.json", + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def object_detection(): + """Detect objects in images.""" + cli = FlashCLI( + ObjectDetector, + ObjectDetectionData, + default_datamodule_builder=from_coco_128, + default_arguments={ + "trainer.max_epochs": 3, + }, + ) + + cli.trainer.save_checkpoint("object_detection_model.pt") + + +if __name__ == "__main__": + object_detection() diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index da660591d3..d75ff23430 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -11,19 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TYPE_CHECKING +from typing import Any, Callable, Dict, Hashable, Optional, Sequence, Tuple, TYPE_CHECKING from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources, FiftyOneDataSource +from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, FiftyOneDataSource from flash.core.data.process import Preprocess -from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _TORCHVISION_AVAILABLE, lazy_import -from flash.image.data import ImagePathsDataSource -from flash.image.detection.transforms import default_transforms - -if _COCO_AVAILABLE: - from pycocotools.coco import COCO +from flash.core.integrations.icevision.data import ( + IceDataParserDataSource, + IceVisionParserDataSource, + IceVisionPathsDataSource, +) +from flash.core.integrations.icevision.transforms import default_transforms +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, lazy_import, requires SampleCollection = None if _FIFTYONE_AVAILABLE: @@ -33,187 +33,138 @@ else: foc, fol = None, None -if _TORCHVISION_AVAILABLE: - from torchvision.datasets.folder import default_loader - - -class COCODataSource(DataSource[Tuple[str, str]]): +if _ICEVISION_AVAILABLE: + from icevision.core import BBox, ClassMap, IsCrowdsRecordComponent, ObjectDetectionRecord + from icevision.data import SingleSplitSplitter + from icevision.parsers import COCOBBoxParser, Parser, VIABBoxParser, VOCBBoxParser + from icevision.utils import ImgSize +else: + Parser = object - def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: - root, ann_file = data - coco = COCO(ann_file) +class FiftyOneParser(Parser): + def __init__(self, data, class_map, label_field, iscrowd): + template_record = ObjectDetectionRecord() + template_record.add_component(IsCrowdsRecordComponent()) + super().__init__(template_record=template_record) - categories = coco.loadCats(coco.getCatIds()) - if categories: - dataset.num_classes = categories[-1]["id"] + 1 + data = data + label_field = label_field + iscrowd = iscrowd - img_ids = list(sorted(coco.imgs.keys())) - paths = coco.loadImgs(img_ids) + self.data = [] + self.class_map = class_map - data = [] + for fp, w, h, sample_labs, sample_boxes, sample_iscrowd in zip( + data.values("filepath"), + data.values("metadata.width"), + data.values("metadata.height"), + data.values(label_field + ".detections.label"), + data.values(label_field + ".detections.bounding_box"), + data.values(label_field + ".detections." + iscrowd), + ): + for lab, box, iscrowd in zip(sample_labs, sample_boxes, sample_iscrowd): + self.data.append((fp, w, h, lab, box, iscrowd)) - for img_id, path in zip(img_ids, paths): - path = path["file_name"] + def __iter__(self) -> Any: + return iter(self.data) - ann_ids = coco.getAnnIds(imgIds=img_id) - annotations = coco.loadAnns(ann_ids) + def __len__(self) -> int: + return len(self.data) - boxes, labels, areas, iscrowd = [], [], [], [] + def record_id(self, o) -> Hashable: + return o[0] - # Ref: https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py - if self.training and all(any(o <= 1 for o in obj["bbox"][2:]) for obj in annotations): - continue + def parse_fields(self, o, record, is_new): + fp, w, h, lab, box, iscrowd = o - for obj in annotations: - xmin = obj["bbox"][0] - ymin = obj["bbox"][1] - xmax = xmin + obj["bbox"][2] - ymax = ymin + obj["bbox"][3] + if iscrowd is None: + iscrowd = 0 - bbox = [xmin, ymin, xmax, ymax] - keep = (bbox[3] > bbox[1]) & (bbox[2] > bbox[0]) - if keep: - boxes.append(bbox) - labels.append(obj["category_id"]) - areas.append(obj["area"]) - iscrowd.append(obj["iscrowd"]) + if is_new: + record.set_filepath(fp) + record.set_img_size(ImgSize(width=w, height=h)) + record.detection.set_class_map(self.class_map) - data.append( - dict( - input=os.path.join(root, path), - target=dict( - boxes=boxes, - labels=labels, - image_id=img_id, - area=areas, - iscrowd=iscrowd, - ) - ) - ) - return data + box = self._reformat_bbox(*box, w, h) - def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - filepath = sample[DefaultDataKeys.INPUT] - img = default_loader(filepath) - sample[DefaultDataKeys.INPUT] = img - w, h = img.size # WxH - sample[DefaultDataKeys.METADATA] = { - "filepath": filepath, - "size": (h, w), - } - return sample - return sample + record.detection.add_bboxes([BBox.from_xyxy(*box)]) + record.detection.add_labels([lab]) + record.detection.add_iscrowds([iscrowd]) + @staticmethod + def _reformat_bbox(xmin, ymin, box_w, box_h, img_w, img_h): + xmin *= img_w + ymin *= img_h + box_w *= img_w + box_h *= img_h + xmax = xmin + box_w + ymax = ymin + box_h + output_bbox = [xmin, ymin, xmax, ymax] + return output_bbox -class ObjectDetectionFiftyOneDataSource(FiftyOneDataSource): +class ObjectDetectionFiftyOneDataSource(IceVisionPathsDataSource, FiftyOneDataSource): def __init__(self, label_field: str = "ground_truth", iscrowd: str = "iscrowd"): - super().__init__(label_field=label_field) + super().__init__() + self.label_field = label_field self.iscrowd = iscrowd @property + @requires("fiftyone") def label_cls(self): return fol.Detections + @requires("fiftyone") def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: self._validate(data) data.compute_metadata() - - filepaths = data.values("filepath") - widths = data.values("metadata.width") - heights = data.values("metadata.height") - labels = data.values(self.label_field + ".detections.label") - bboxes = data.values(self.label_field + ".detections.bounding_box") - iscrowds = data.values(self.label_field + ".detections." + self.iscrowd) - classes = self._get_classes(data) - class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} - if dataset is not None: - dataset.num_classes = len(classes) - - output_data = [] - img_id = 1 - for fp, w, h, sample_labs, sample_boxes, sample_iscrowd in zip( - filepaths, widths, heights, labels, bboxes, iscrowds - ): - output_boxes = [] - output_labs = [] - output_iscrowd = [] - output_areas = [] - for lab, box, iscrowd in zip(sample_labs, sample_boxes, sample_iscrowd): - output_box, output_area = self._reformat_bbox(box[0], box[1], box[2], box[3], w, h) - output_areas.append(output_area) - output_labs.append(class_to_idx[lab]) - output_boxes.append(output_box) - if iscrowd is None: - iscrowd = 0 - output_iscrowd.append(iscrowd) - output_data.append( - dict( - input=fp, - target=dict( - boxes=output_boxes, - labels=output_labs, - image_id=img_id, - area=output_areas, - iscrowd=output_iscrowd, - ) - ) - ) - img_id += 1 - - return output_data + class_map = ClassMap(classes) + dataset.num_classes = len(class_map) - @staticmethod - def load_sample(sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - filepath = sample[DefaultDataKeys.INPUT] - img = default_loader(filepath) - sample[DefaultDataKeys.INPUT] = img - w, h = img.size # WxH - sample[DefaultDataKeys.METADATA] = { - "filepath": filepath, - "size": (h, w), - } - return sample + parser = FiftyOneParser(data, class_map, self.label_field, self.iscrowd) + records = parser.parse(data_splitter=SingleSplitSplitter()) + return [{DefaultDataKeys.INPUT: record} for record in records[0]] @staticmethod - def _reformat_bbox(xmin, ymin, box_w, box_h, img_w, img_h): - xmin *= img_w - ymin *= img_h - box_w *= img_w - box_h *= img_h - xmax = xmin + box_w - ymax = ymin + box_h - output_bbox = [xmin, ymin, xmax, ymax] - return output_bbox, box_w * box_h + @requires("fiftyone") + def predict_load_data(data: SampleCollection, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: + return [{DefaultDataKeys.INPUT: f} for f in data.values("filepath")] class ObjectDetectionPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, + image_size: Tuple[int, int] = (128, 128), + parser: Optional[Callable] = None, **data_source_kwargs: Any, ): + self.image_size = image_size + super().__init__( train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, data_sources={ + "coco": IceVisionParserDataSource(parser=COCOBBoxParser), + "via": IceVisionParserDataSource(parser=VIABBoxParser), + "voc": IceVisionParserDataSource(parser=VOCBBoxParser), + DefaultDataSources.FILES: IceVisionPathsDataSource(), + DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser), DefaultDataSources.FIFTYONE: ObjectDetectionFiftyOneDataSource(**data_source_kwargs), - DefaultDataSources.FILES: ImagePathsDataSource(), - DefaultDataSources.FOLDERS: ImagePathsDataSource(), - "coco": COCODataSource(), }, default_data_source=DefaultDataSources.FILES, ) + self._default_collate = self._identity + def get_state_dict(self) -> Dict[str, Any]: return {**self.transforms} @@ -222,7 +173,10 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): return cls(**state_dict) def default_transforms(self) -> Optional[Dict[str, Callable]]: - return default_transforms() + return default_transforms(self.image_size) + + def train_default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) class ObjectDetectionData(DataModule): @@ -238,9 +192,11 @@ def from_coco( val_ann_file: Optional[str] = None, test_folder: Optional[str] = None, test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, train_transform: Optional[Dict[str, Callable]] = None, val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, data_fetcher: Optional[BaseDataFetcher] = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, @@ -248,8 +204,8 @@ def from_coco( num_workers: Optional[int] = None, **preprocess_kwargs: Any, ): - """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data - folders and corresponding target folders. + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders + and annotation files in the COCO format. Args: train_folder: The folder containing the train data. @@ -258,12 +214,15 @@ def from_coco( val_ann_file: The COCO format annotation file. test_folder: The folder containing the test data. test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. train_transform: The dictionary of transforms to use during training which maps :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. val_transform: The dictionary of transforms to use during validation which maps :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. test_transform: The dictionary of transforms to use during testing which maps :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the :class:`~flash.core.data.data_module.DataModule`. preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the @@ -280,7 +239,7 @@ def from_coco( Examples:: - data_module = SemanticSegmentationData.from_coco( + data_module = ObjectDetectionData.from_coco( train_folder="train_folder", train_ann_file="annotations.json", ) @@ -290,9 +249,169 @@ def from_coco( (train_folder, train_ann_file) if train_folder else None, (val_folder, val_ann_file) if val_folder else None, (test_folder, test_ann_file) if test_folder else None, + predict_folder, train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + @classmethod + def from_voc( + cls, + train_folder: Optional[str] = None, + train_ann_file: Optional[str] = None, + val_folder: Optional[str] = None, + val_ann_file: Optional[str] = None, + test_folder: Optional[str] = None, + test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders + and annotation files in the VOC format. + + Args: + train_folder: The folder containing the train data. + train_ann_file: The COCO format annotation file. + val_folder: The folder containing the validation data. + val_ann_file: The COCO format annotation file. + test_folder: The folder containing the test data. + test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = ObjectDetectionData.from_voc( + train_folder="train_folder", + train_ann_file="annotations.json", + ) + """ + return cls.from_data_source( + "voc", + (train_folder, train_ann_file) if train_folder else None, + (val_folder, val_ann_file) if val_folder else None, + (test_folder, test_ann_file) if test_folder else None, + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + @classmethod + def from_via( + cls, + train_folder: Optional[str] = None, + train_ann_file: Optional[str] = None, + val_folder: Optional[str] = None, + val_ann_file: Optional[str] = None, + test_folder: Optional[str] = None, + test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders + and annotation files in the VIA format. + + Args: + train_folder: The folder containing the train data. + train_ann_file: The COCO format annotation file. + val_folder: The folder containing the validation data. + val_ann_file: The COCO format annotation file. + test_folder: The folder containing the test data. + test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = ObjectDetectionData.from_via( + train_folder="train_folder", + train_ann_file="annotations.json", + ) + """ + return cls.from_data_source( + "via", + (train_folder, train_ann_file) if train_folder else None, + (val_folder, val_ann_file) if val_folder else None, + (test_folder, test_ann_file) if test_folder else None, + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, data_fetcher=data_fetcher, preprocess=preprocess, val_split=val_split, diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index 41edea48ee..c2bcd606f6 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -11,55 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Dict, List, Mapping, Optional, Type, Union import torch -from torch import nn, tensor from torch.optim import Optimizer -from flash.core.data.data_source import DefaultDataKeys +from flash.core.adapter import AdapterTask from flash.core.data.process import Serializer -from flash.core.model import Task from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _TORCHVISION_AVAILABLE -from flash.image.backbones import OBJ_DETECTION_BACKBONES -from flash.image.detection.finetuning import ObjectDetectionFineTuning -from flash.image.detection.serialization import DetectionLabels +from flash.image.detection.backbones import OBJECT_DETECTION_HEADS -if _TORCHVISION_AVAILABLE: - import torchvision - from torchvision.models.detection.faster_rcnn import FasterRCNN, FastRCNNPredictor - from torchvision.models.detection.retinanet import RetinaNet, RetinaNetHead - from torchvision.models.detection.rpn import AnchorGenerator - from torchvision.ops import box_iou - _models = { - "fasterrcnn": torchvision.models.detection.fasterrcnn_resnet50_fpn, - "retinanet": torchvision.models.detection.retinanet_resnet50_fpn, - } - -else: - AnchorGenerator = None - - -def _evaluate_iou(target, pred): - """ - Evaluate intersection over union (IOU) for target from dataset and output prediction from model - """ - if pred["boxes"].shape[0] == 0: - # no box detected, 0 IOU - return tensor(0.0, device=pred["boxes"].device) - return box_iou(target["boxes"], pred["boxes"]).diag().mean() - - -class ObjectDetector(Task): +class ObjectDetector(AdapterTask): """The ``ObjectDetector`` is a :class:`~flash.Task` for detecting objects in images. For more details, see :ref:`object_detection`. Args: num_classes: the number of classes for detection, including background model: a string of :attr`_models`. Defaults to 'fasterrcnn'. - backbone: Pretained backbone CNN architecture. Constructs a model with a + backbone: Pretrained backbone CNN architecture. Constructs a model with a ResNet-50-FPN backbone when no backbone is specified. fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs. pretrained: if true, returns a model pre-trained on COCO train2017 @@ -76,135 +46,40 @@ class ObjectDetector(Task): """ - backbones: FlashRegistry = OBJ_DETECTION_BACKBONES + heads: FlashRegistry = OBJECT_DETECTION_HEADS required_extras: str = "image" def __init__( self, num_classes: int, - model: str = "fasterrcnn", - backbone: Optional[str] = None, - fpn: bool = True, + backbone: Optional[str] = "resnet18_fpn", + head: Optional[str] = "retinanet", pretrained: bool = True, - pretrained_backbone: bool = True, - trainable_backbone_layers: int = 3, - anchor_generator: Optional[Type['AnchorGenerator']] = None, - loss=None, - metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None, - optimizer: Type[Optimizer] = torch.optim.AdamW, - learning_rate: float = 1e-3, + optimizer: Type[Optimizer] = torch.optim.Adam, + learning_rate: float = 5e-4, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, **kwargs: Any, ): self.save_hyperparameters() - if model in _models: - model = ObjectDetector.get_model( - model, num_classes, backbone, fpn, pretrained, pretrained_backbone, trainable_backbone_layers, - anchor_generator, **kwargs - ) - else: - ValueError(f"{model} is not supported yet.") + metadata = self.heads.get(head, with_metadata=True) + adapter = metadata["metadata"]["adapter"].from_task( + self, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + **kwargs, + ) super().__init__( - model=model, - loss_fn=loss, - metrics=metrics, + adapter, learning_rate=learning_rate, optimizer=optimizer, - serializer=serializer or DetectionLabels(), + serializer=serializer, ) - @staticmethod - def get_model( - model_name, - num_classes, - backbone, - fpn, - pretrained, - pretrained_backbone, - trainable_backbone_layers, - anchor_generator, - **kwargs, - ): - if backbone is None: - # Constructs a model with a ResNet-50-FPN backbone when no backbone is specified. - if model_name == "fasterrcnn": - model = _models[model_name]( - pretrained=pretrained, - pretrained_backbone=pretrained_backbone, - trainable_backbone_layers=trainable_backbone_layers, - ) - in_features = model.roi_heads.box_predictor.cls_score.in_features - head = FastRCNNPredictor(in_features, num_classes) - model.roi_heads.box_predictor = head - else: - model = _models[model_name](pretrained=pretrained, pretrained_backbone=pretrained_backbone) - model.head = RetinaNetHead( - in_channels=model.backbone.out_channels, - num_anchors=model.head.classification_head.num_anchors, - num_classes=num_classes, - **kwargs - ) - else: - backbone_model, num_features = ObjectDetector.backbones.get(backbone)( - pretrained=pretrained_backbone, - trainable_layers=trainable_backbone_layers, - **kwargs, - ) - backbone_model.out_channels = num_features - if anchor_generator is None: - anchor_generator = AnchorGenerator( - sizes=((32, 64, 128, 256, 512), ), aspect_ratios=((0.5, 1.0, 2.0), ) - ) if not hasattr(backbone_model, "fpn") else None - - if model_name == "fasterrcnn": - model = FasterRCNN(backbone_model, num_classes=num_classes, rpn_anchor_generator=anchor_generator) - else: - model = RetinaNet(backbone_model, num_classes=num_classes, anchor_generator=anchor_generator) - return model - - def forward(self, x: List[torch.Tensor]) -> Any: - return self.model(x) - - def training_step(self, batch, batch_idx) -> Any: - """The training step. Overrides ``Task.training_step`` - """ - images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] - targets = [dict(t.items()) for t in targets] - - # fasterrcnn takes both images and targets for training, returns loss_dict - loss_dict = self.model(images, targets) - loss = sum(loss_dict.values()) - self.log_dict({f"train_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True, prog_bar=True) - return loss - - def validation_step(self, batch, batch_idx): - images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] - # fasterrcnn takes only images for eval() mode - outs = self(images) - iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() - self.log("val_iou", iou) - - def test_step(self, batch, batch_idx): - images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] - # fasterrcnn takes only images for eval() mode - outs = self(images) - iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean() - self.log("test_iou", iou) - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - images = batch[DefaultDataKeys.INPUT] - batch[DefaultDataKeys.PREDS] = self(images) - return batch - - def configure_finetune_callback(self): - return [ObjectDetectionFineTuning(train_bn=True)] - def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: - """ - This function is used only for debugging usage with CI - """ - # todo (tchaton) Improve convergence - # history[-1]["val_iou"] + """This function is used only for debugging usage with CI.""" + # todo diff --git a/flash/image/detection/serialization.py b/flash/image/detection/serialization.py index 561fe0910d..e50614d0ef 100644 --- a/flash/image/detection/serialization.py +++ b/flash/image/detection/serialization.py @@ -17,7 +17,7 @@ from flash.core.data.data_source import DefaultDataKeys, LabelsState from flash.core.data.process import Serializer -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires Detections = None if _FIFTYONE_AVAILABLE: @@ -48,15 +48,13 @@ class FiftyOneDetectionLabels(Serializer): list of FiftyOne labels (False) """ + @requires("fiftyone") def __init__( self, labels: Optional[List[str]] = None, threshold: Optional[float] = None, return_filepath: bool = False, ): - if not _FIFTYONE_AVAILABLE: - raise ModuleNotFoundError("Please, run `pip install fiftyone`.") - super().__init__() self._labels = labels self.threshold = threshold @@ -89,7 +87,7 @@ def serialize(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]] if self.threshold is not None and confidence < self.threshold: continue - xmin, ymin, xmax, ymax = [c.tolist() for c in det["boxes"]] + xmin, ymin, xmax, ymax = (c.tolist() for c in det["boxes"]) box = [ xmin / width, ymin / height, @@ -103,11 +101,13 @@ def serialize(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]] else: label = str(int(label)) - detections.append(fo.Detection( - label=label, - bounding_box=box, - confidence=confidence, - )) + detections.append( + fo.Detection( + label=label, + bounding_box=box, + confidence=confidence, + ) + ) fo_predictions = fo.Detections(detections=detections) if self.return_filepath: filepath = sample[DefaultDataKeys.METADATA]["filepath"] diff --git a/flash/image/detection/transforms.py b/flash/image/detection/transforms.py deleted file mode 100644 index 1f54854376..0000000000 --- a/flash/image/detection/transforms.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Callable, Dict, Sequence - -import torch -from torch import nn - -from flash.core.data.transforms import ApplyToKeys -from flash.core.utilities.imports import _TORCHVISION_AVAILABLE - -if _TORCHVISION_AVAILABLE: - import torchvision - - -def collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence[Any]]: - return {key: [sample[key] for sample in samples] for key in samples[0]} - - -def default_transforms() -> Dict[str, Callable]: - """The default transforms for object detection: convert the image and targets to a tensor, collate the batch.""" - return { - "to_tensor_transform": nn.Sequential( - ApplyToKeys('input', torchvision.transforms.ToTensor()), - ApplyToKeys( - 'target', - nn.Sequential( - ApplyToKeys('boxes', torch.as_tensor), - ApplyToKeys('labels', torch.as_tensor), - ApplyToKeys('image_id', torch.as_tensor), - ApplyToKeys('area', torch.as_tensor), - ApplyToKeys('iscrowd', torch.as_tensor), - ) - ), - ), - "collate": collate, - } diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index 657bc3f65c..a8cab9b90a 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -23,17 +23,18 @@ from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _IMAGE_AVAILABLE +from flash.core.utilities.isinstance import _isinstance from flash.image.classification.data import ImageClassificationPreprocess if _IMAGE_AVAILABLE: - from flash.image.backbones import IMAGE_CLASSIFIER_BACKBONES + from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES else: IMAGE_CLASSIFIER_BACKBONES = FlashRegistry("backbones") class ImageEmbedder(Task): - """The ``ImageEmbedder`` is a :class:`~flash.Task` for obtaining feature vectors (embeddings) from images. For more - details, see :ref:`image_embedder`. + """The ``ImageEmbedder`` is a :class:`~flash.Task` for obtaining feature vectors (embeddings) from images. For + more details, see :ref:`image_embedder`. Args: embedding_dim: Dimension of the embedded vector. ``None`` uses the default from the backbone. @@ -47,7 +48,6 @@ class ImageEmbedder(Task): `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`. learning_rate: Learning rate to use for training, defaults to ``1e-3``. pooling_fn: Function used to pool image to generate embeddings, defaults to :func:`torch.max`. - """ backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES @@ -63,7 +63,7 @@ def __init__( optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, metrics: Union[Metric, Callable, Mapping, Sequence, None] = (Accuracy()), learning_rate: float = 1e-3, - pooling_fn: Callable = torch.max + pooling_fn: Callable = torch.max, ): super().__init__( model=None, @@ -71,7 +71,7 @@ def __init__( optimizer=optimizer, metrics=metrics, learning_rate=learning_rate, - preprocess=ImageClassificationPreprocess() + preprocess=ImageClassificationPreprocess(), ) self.save_hyperparameters() @@ -89,14 +89,14 @@ def __init__( nn.Flatten(), nn.Linear(num_features, embedding_dim), ) - rank_zero_warn('Adding linear layer on top of backbone. Remember to finetune first before using!') + rank_zero_warn("Adding linear layer on top of backbone. Remember to finetune first before using!") def apply_pool(self, x): x = self.pooling_fn(x, dim=-1) - if torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor]): + if _isinstance(x, Tuple[torch.Tensor, torch.Tensor]): x = x[0] x = self.pooling_fn(x, dim=-1) - if torch.jit.isinstance(x, Tuple[torch.Tensor, torch.Tensor]): + if _isinstance(x, Tuple[torch.Tensor, torch.Tensor]): x = x[0] return x @@ -107,7 +107,7 @@ def forward(self, x) -> torch.Tensor: if isinstance(x, tuple): x = x[-1] - if x.dim() == 4 and self.embedding_dim: + if x.dim() == 4 and not self.embedding_dim: x = self.apply_pool(x) x = self.head(x) @@ -126,5 +126,5 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = (batch[DefaultDataKeys.INPUT]) + batch = batch[DefaultDataKeys.INPUT] return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) diff --git a/flash/image/instance_segmentation/__init__.py b/flash/image/instance_segmentation/__init__.py new file mode 100644 index 0000000000..c5659822c8 --- /dev/null +++ b/flash/image/instance_segmentation/__init__.py @@ -0,0 +1,2 @@ +from flash.image.instance_segmentation.data import InstanceSegmentationData # noqa: F401 +from flash.image.instance_segmentation.model import InstanceSegmentation # noqa: F401 diff --git a/flash/image/instance_segmentation/backbones.py b/flash/image/instance_segmentation/backbones.py new file mode 100644 index 0000000000..9811d6fa78 --- /dev/null +++ b/flash/image/instance_segmentation/backbones.py @@ -0,0 +1,81 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Optional + +from flash.core.adapter import Adapter +from flash.core.integrations.icevision.adapter import IceVisionAdapter, SimpleCOCOMetric +from flash.core.integrations.icevision.backbones import ( + get_backbones, + icevision_model_adapter, + load_icevision_ignore_image_size, +) +from flash.core.model import Task +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE +from flash.core.utilities.providers import _ICEVISION, _MMDET, _TORCHVISION + +if _ICEVISION_AVAILABLE: + from icevision import models as icevision_models + from icevision.metrics import COCOMetricType + from icevision.metrics import Metric as IceVisionMetric + +INSTANCE_SEGMENTATION_HEADS = FlashRegistry("heads") + + +class IceVisionInstanceSegmentationAdapter(IceVisionAdapter): + @classmethod + def from_task( + cls, + task: Task, + num_classes: int, + backbone: str = "resnet18_fpn", + head: str = "mask_rcnn", + pretrained: bool = True, + metrics: Optional["IceVisionMetric"] = None, + image_size: Optional = None, + **kwargs, + ) -> Adapter: + return super().from_task( + task, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + metrics=metrics or [SimpleCOCOMetric(COCOMetricType.mask)], + image_size=image_size, + **kwargs, + ) + + +if _ICEVISION_AVAILABLE: + if _TORCHVISION_AVAILABLE: + model_type = icevision_models.torchvision.mask_rcnn + INSTANCE_SEGMENTATION_HEADS( + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + model_type.__name__.split(".")[-1], + backbones=get_backbones(model_type), + adapter=IceVisionInstanceSegmentationAdapter, + providers=[_ICEVISION, _TORCHVISION], + ) + + if _module_available("mmdet"): + model_type = icevision_models.mmdet.mask_rcnn + INSTANCE_SEGMENTATION_HEADS( + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + f"mmdet_{model_type.__name__.split('.')[-1]}", + backbones=get_backbones(model_type), + adapter=IceVisionInstanceSegmentationAdapter, + providers=[_ICEVISION, _MMDET], + ) diff --git a/flash/image/instance_segmentation/cli.py b/flash/image/instance_segmentation/cli.py new file mode 100644 index 0000000000..3b0842c436 --- /dev/null +++ b/flash/image/instance_segmentation/cli.py @@ -0,0 +1,66 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Callable, Optional + +from flash.core.utilities.flash_cli import FlashCLI +from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires_extras +from flash.image import InstanceSegmentation, InstanceSegmentationData + +if _ICEDATA_AVAILABLE: + import icedata + +__all__ = ["instance_segmentation"] + + +@requires_extras("image") +def from_pets( + val_split: float = 0.1, + batch_size: int = 4, + num_workers: Optional[int] = None, + parser: Optional[Callable] = None, + **preprocess_kwargs, +) -> InstanceSegmentationData: + """Downloads and loads the pets data set from icedata.""" + data_dir = icedata.pets.load_data() + + if parser is None: + parser = partial(icedata.pets.parser, mask=True) + + return InstanceSegmentationData.from_folders( + train_folder=data_dir, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + parser=parser, + **preprocess_kwargs, + ) + + +def instance_segmentation(): + """Segment object instances in images.""" + cli = FlashCLI( + InstanceSegmentation, + InstanceSegmentationData, + default_datamodule_builder=from_pets, + default_arguments={ + "trainer.max_epochs": 3, + }, + ) + + cli.trainer.save_checkpoint("instance_segmentation_model.pt") + + +if __name__ == "__main__": + instance_segmentation() diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py new file mode 100644 index 0000000000..b67e606683 --- /dev/null +++ b/flash/image/instance_segmentation/data.py @@ -0,0 +1,234 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, Optional, Tuple + +from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DefaultDataSources +from flash.core.data.process import Preprocess +from flash.core.integrations.icevision.data import ( + IceDataParserDataSource, + IceVisionParserDataSource, + IceVisionPathsDataSource, +) +from flash.core.integrations.icevision.transforms import default_transforms +from flash.core.utilities.imports import _ICEVISION_AVAILABLE + +if _ICEVISION_AVAILABLE: + from icevision.parsers import COCOMaskParser, VOCMaskParser + + +class InstanceSegmentationPreprocess(Preprocess): + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + image_size: Tuple[int, int] = (128, 128), + parser: Optional[Callable] = None, + ): + self.image_size = image_size + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + "coco": IceVisionParserDataSource(parser=COCOMaskParser), + "voc": IceVisionParserDataSource(parser=VOCMaskParser), + DefaultDataSources.FILES: IceVisionPathsDataSource(), + DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser), + }, + default_data_source=DefaultDataSources.FILES, + ) + + self._default_collate = self._identity + + def get_state_dict(self) -> Dict[str, Any]: + return {**self.transforms} + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) + + def default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) + + def train_default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) + + +class InstanceSegmentationData(DataModule): + + preprocess_cls = InstanceSegmentationPreprocess + + @classmethod + def from_coco( + cls, + train_folder: Optional[str] = None, + train_ann_file: Optional[str] = None, + val_folder: Optional[str] = None, + val_ann_file: Optional[str] = None, + test_folder: Optional[str] = None, + test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the + given data folders and annotation files in the COCO format. + + Args: + train_folder: The folder containing the train data. + train_ann_file: The COCO format annotation file. + val_folder: The folder containing the validation data. + val_ann_file: The COCO format annotation file. + test_folder: The folder containing the test data. + test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = InstanceSegmentationData.from_coco( + train_folder="train_folder", + train_ann_file="annotations.json", + ) + """ + return cls.from_data_source( + "coco", + (train_folder, train_ann_file) if train_folder else None, + (val_folder, val_ann_file) if val_folder else None, + (test_folder, test_ann_file) if test_folder else None, + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + @classmethod + def from_voc( + cls, + train_folder: Optional[str] = None, + train_ann_file: Optional[str] = None, + val_folder: Optional[str] = None, + val_ann_file: Optional[str] = None, + test_folder: Optional[str] = None, + test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the + given data folders and annotation files in the VOC format. + + Args: + train_folder: The folder containing the train data. + train_ann_file: The COCO format annotation file. + val_folder: The folder containing the validation data. + val_ann_file: The COCO format annotation file. + test_folder: The folder containing the test data. + test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = InstanceSegmentationData.from_voc( + train_folder="train_folder", + train_ann_file="annotations.json", + ) + """ + return cls.from_data_source( + "voc", + (train_folder, train_ann_file) if train_folder else None, + (val_folder, val_ann_file) if val_folder else None, + (test_folder, test_ann_file) if test_folder else None, + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py new file mode 100644 index 0000000000..52f2706554 --- /dev/null +++ b/flash/image/instance_segmentation/model.py @@ -0,0 +1,85 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Mapping, Optional, Type, Union + +import torch +from torch.optim import Optimizer + +from flash.core.adapter import AdapterTask +from flash.core.data.process import Serializer +from flash.core.registry import FlashRegistry +from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS + + +class InstanceSegmentation(AdapterTask): + """The ``InstanceSegmentation`` is a :class:`~flash.Task` for detecting objects in images. For more details, see + :ref:`object_detection`. + + Args: + num_classes: the number of classes for detection, including background + model: a string of :attr`_models`. Defaults to 'fasterrcnn'. + backbone: Pretained backbone CNN architecture. Constructs a model with a + ResNet-50-FPN backbone when no backbone is specified. + fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs. + pretrained: if true, returns a model pre-trained on COCO train2017 + pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet + trainable_backbone_layers: number of trainable resnet layers starting from final block. + Only applicable for `fasterrcnn`. + loss: the function(s) to update the model with. Has no effect for torchvision detection models. + metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger. + Changing this argument currently has no effect. + optimizer: The optimizer to use for training. Can either be the actual class or the class name. + pretrained: Whether the model from torchvision should be loaded with it's pretrained weights. + Has no effect for custom models. + learning_rate: The learning rate to use for training + + """ + + heads: FlashRegistry = INSTANCE_SEGMENTATION_HEADS + + required_extras: str = "image" + + def __init__( + self, + num_classes: int, + backbone: Optional[str] = "resnet18_fpn", + head: Optional[str] = "mask_rcnn", + pretrained: bool = True, + optimizer: Type[Optimizer] = torch.optim.Adam, + learning_rate: float = 5e-4, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + **kwargs: Any, + ): + self.save_hyperparameters() + + metadata = self.heads.get(head, with_metadata=True) + adapter = metadata["metadata"]["adapter"].from_task( + self, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + **kwargs, + ) + + super().__init__( + adapter, + learning_rate=learning_rate, + optimizer=optimizer, + serializer=serializer, + ) + + def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: + """This function is used only for debugging usage with CI.""" + # todo diff --git a/flash/image/keypoint_detection/__init__.py b/flash/image/keypoint_detection/__init__.py new file mode 100644 index 0000000000..d397086e24 --- /dev/null +++ b/flash/image/keypoint_detection/__init__.py @@ -0,0 +1,2 @@ +from flash.image.keypoint_detection.data import KeypointDetectionData # noqa: F401 +from flash.image.keypoint_detection.model import KeypointDetector # noqa: F401 diff --git a/flash/image/keypoint_detection/backbones.py b/flash/image/keypoint_detection/backbones.py new file mode 100644 index 0000000000..72334761f2 --- /dev/null +++ b/flash/image/keypoint_detection/backbones.py @@ -0,0 +1,72 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Optional + +from flash.core.adapter import Adapter +from flash.core.integrations.icevision.adapter import IceVisionAdapter +from flash.core.integrations.icevision.backbones import ( + get_backbones, + icevision_model_adapter, + load_icevision_ignore_image_size, +) +from flash.core.model import Task +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.providers import _ICEVISION, _TORCHVISION + +if _ICEVISION_AVAILABLE: + from icevision import models as icevision_models + from icevision.metrics import Metric as IceVisionMetric + +KEYPOINT_DETECTION_HEADS = FlashRegistry("heads") + + +class IceVisionKeypointDetectionAdapter(IceVisionAdapter): + @classmethod + def from_task( + cls, + task: Task, + num_keypoints: int, + num_classes: int = 2, + backbone: str = "resnet18_fpn", + head: str = "keypoint_rcnn", + pretrained: bool = True, + metrics: Optional["IceVisionMetric"] = None, + image_size: Optional = None, + **kwargs, + ) -> Adapter: + return super().from_task( + task, + num_keypoints=num_keypoints, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + metrics=metrics, + image_size=image_size, + **kwargs, + ) + + +if _ICEVISION_AVAILABLE: + if _TORCHVISION_AVAILABLE: + model_type = icevision_models.torchvision.keypoint_rcnn + KEYPOINT_DETECTION_HEADS( + partial(load_icevision_ignore_image_size, icevision_model_adapter, model_type), + model_type.__name__.split(".")[-1], + backbones=get_backbones(model_type), + adapter=IceVisionKeypointDetectionAdapter, + providers=[_ICEVISION, _TORCHVISION], + ) diff --git a/flash/image/keypoint_detection/cli.py b/flash/image/keypoint_detection/cli.py new file mode 100644 index 0000000000..b97345679e --- /dev/null +++ b/flash/image/keypoint_detection/cli.py @@ -0,0 +1,66 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Optional + +from flash.core.utilities.flash_cli import FlashCLI +from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires_extras +from flash.image import KeypointDetectionData, KeypointDetector + +if _ICEDATA_AVAILABLE: + import icedata + +__all__ = ["keypoint_detection"] + + +@requires_extras("image") +def from_biwi( + val_split: float = 0.1, + batch_size: int = 4, + num_workers: Optional[int] = None, + parser: Optional[Callable] = None, + **preprocess_kwargs, +) -> KeypointDetectionData: + """Downloads and loads the BIWI data set from icedata.""" + data_dir = icedata.biwi.load_data() + + if parser is None: + parser = icedata.biwi.parser + + return KeypointDetectionData.from_folders( + train_folder=data_dir, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + parser=parser, + **preprocess_kwargs, + ) + + +def keypoint_detection(): + """Detect keypoints in images.""" + cli = FlashCLI( + KeypointDetector, + KeypointDetectionData, + default_datamodule_builder=from_biwi, + default_arguments={ + "model.num_keypoints": 1, + "trainer.max_epochs": 3, + }, + ) + + cli.trainer.save_checkpoint("keypoint_detection_model.pt") + + +if __name__ == "__main__": + keypoint_detection() diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py new file mode 100644 index 0000000000..48e4b06a44 --- /dev/null +++ b/flash/image/keypoint_detection/data.py @@ -0,0 +1,154 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, Optional, Tuple + +from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DefaultDataSources +from flash.core.data.process import Preprocess +from flash.core.integrations.icevision.data import ( + IceDataParserDataSource, + IceVisionParserDataSource, + IceVisionPathsDataSource, +) +from flash.core.integrations.icevision.transforms import default_transforms +from flash.core.utilities.imports import _ICEVISION_AVAILABLE + +if _ICEVISION_AVAILABLE: + from icevision.parsers import COCOKeyPointsParser + + +class KeypointDetectionPreprocess(Preprocess): + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + image_size: Tuple[int, int] = (128, 128), + parser: Optional[Callable] = None, + ): + self.image_size = image_size + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + "coco": IceVisionParserDataSource(parser=COCOKeyPointsParser), + DefaultDataSources.FILES: IceVisionPathsDataSource(), + DefaultDataSources.FOLDERS: IceDataParserDataSource(parser=parser), + }, + default_data_source=DefaultDataSources.FILES, + ) + + self._default_collate = self._identity + + def get_state_dict(self) -> Dict[str, Any]: + return {**self.transforms} + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) + + def default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) + + def train_default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.image_size) + + +class KeypointDetectionData(DataModule): + + preprocess_cls = KeypointDetectionPreprocess + + @classmethod + def from_coco( + cls, + train_folder: Optional[str] = None, + train_ann_file: Optional[str] = None, + val_folder: Optional[str] = None, + val_ann_file: Optional[str] = None, + test_folder: Optional[str] = None, + test_ann_file: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` object from the given data + folders and annotation files in the COCO format. + + Args: + train_folder: The folder containing the train data. + train_ann_file: The COCO format annotation file. + val_folder: The folder containing the validation data. + val_ann_file: The COCO format annotation file. + test_folder: The folder containing the test data. + test_ann_file: The COCO format annotation file. + predict_folder: The folder containing the predict data. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = KeypointDetectionData.from_coco( + train_folder="train_folder", + train_ann_file="annotations.json", + ) + """ + return cls.from_data_source( + "coco", + (train_folder, train_ann_file) if train_folder else None, + (val_folder, val_ann_file) if val_folder else None, + (test_folder, test_ann_file) if test_folder else None, + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py new file mode 100644 index 0000000000..b85177d083 --- /dev/null +++ b/flash/image/keypoint_detection/model.py @@ -0,0 +1,87 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Mapping, Optional, Type, Union + +import torch +from torch.optim import Optimizer + +from flash.core.adapter import AdapterTask +from flash.core.data.process import Serializer +from flash.core.registry import FlashRegistry +from flash.image.keypoint_detection.backbones import KEYPOINT_DETECTION_HEADS + + +class KeypointDetector(AdapterTask): + """The ``ObjectDetector`` is a :class:`~flash.Task` for detecting objects in images. For more details, see + :ref:`object_detection`. + + Args: + num_classes: the number of classes for detection, including background + model: a string of :attr`_models`. Defaults to 'fasterrcnn'. + backbone: Pretained backbone CNN architecture. Constructs a model with a + ResNet-50-FPN backbone when no backbone is specified. + fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs. + pretrained: if true, returns a model pre-trained on COCO train2017 + pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet + trainable_backbone_layers: number of trainable resnet layers starting from final block. + Only applicable for `fasterrcnn`. + loss: the function(s) to update the model with. Has no effect for torchvision detection models. + metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger. + Changing this argument currently has no effect. + optimizer: The optimizer to use for training. Can either be the actual class or the class name. + pretrained: Whether the model from torchvision should be loaded with it's pretrained weights. + Has no effect for custom models. + learning_rate: The learning rate to use for training + + """ + + heads: FlashRegistry = KEYPOINT_DETECTION_HEADS + + required_extras: str = "image" + + def __init__( + self, + num_keypoints: int, + num_classes: int = 2, + backbone: Optional[str] = "resnet18_fpn", + head: Optional[str] = "keypoint_rcnn", + pretrained: bool = True, + optimizer: Type[Optimizer] = torch.optim.Adam, + learning_rate: float = 5e-4, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + **kwargs: Any, + ): + self.save_hyperparameters() + + metadata = self.heads.get(head, with_metadata=True) + adapter = metadata["metadata"]["adapter"].from_task( + self, + num_keypoints=num_keypoints, + num_classes=num_classes, + backbone=backbone, + head=head, + pretrained=pretrained, + **kwargs, + ) + + super().__init__( + adapter, + learning_rate=learning_rate, + optimizer=optimizer, + serializer=serializer, + ) + + def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: + """This function is used only for debugging usage with CI.""" + # todo diff --git a/flash/image/segmentation/backbones.py b/flash/image/segmentation/backbones.py index de6235cf11..0c73cc14fa 100644 --- a/flash/image/segmentation/backbones.py +++ b/flash/image/segmentation/backbones.py @@ -14,45 +14,31 @@ from functools import partial from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _TORCHVISION_AVAILABLE -from flash.image.backbones import catch_url_error +from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE +from flash.core.utilities.providers import _SEGMENTATION_MODELS -if _TORCHVISION_AVAILABLE: - from torchvision.models import mobilenetv3, resnet - -MOBILENET_MODELS = ["mobilenet_v3_large"] -RESNET_MODELS = ["resnet50", "resnet101"] +if _SEGMENTATION_MODELS_AVAILABLE: + import segmentation_models_pytorch as smp SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones") -if _TORCHVISION_AVAILABLE: +if _SEGMENTATION_MODELS_AVAILABLE: - def _load_resnet(model_name: str, pretrained: bool = True): - backbone = resnet.__dict__[model_name]( - pretrained=pretrained, - replace_stride_with_dilation=[False, True, True], - ) - return backbone + ENCODERS = smp.encoders.get_encoder_names() - for model_name in RESNET_MODELS: - SEMANTIC_SEGMENTATION_BACKBONES( - fn=catch_url_error(partial(_load_resnet, model_name)), - name=model_name, - namespace="image/segmentation", - package="torchvision", - ) - - def _load_mobilenetv3(model_name: str, pretrained: bool = True): - backbone = mobilenetv3.__dict__[model_name]( - pretrained=pretrained, - _dilated=True, - ) + def _load_smp_backbone(backbone: str, **_) -> str: return backbone - for model_name in MOBILENET_MODELS: + for encoder_name in ENCODERS: + short_name = encoder_name + if short_name.startswith("timm-"): + short_name = encoder_name[5:] + + available_weights = smp.encoders.encoders[encoder_name]["pretrained_settings"].keys() SEMANTIC_SEGMENTATION_BACKBONES( - fn=catch_url_error(partial(_load_mobilenetv3, model_name)), - name=model_name, + partial(_load_smp_backbone, backbone=encoder_name), + name=short_name, namespace="image/segmentation", - package="torchvision", + weights_paths=available_weights, + providers=_SEGMENTATION_MODELS, ) diff --git a/flash/image/segmentation/cli.py b/flash/image/segmentation/cli.py new file mode 100644 index 0000000000..64cb0c3d93 --- /dev/null +++ b/flash/image/segmentation/cli.py @@ -0,0 +1,61 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from flash.core.data.utils import download_data +from flash.core.utilities.flash_cli import FlashCLI +from flash.image import SemanticSegmentation, SemanticSegmentationData + +__all__ = ["semantic_segmentation"] + + +def from_carla( + num_classes: int = 21, + val_split: float = 0.1, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> SemanticSegmentationData: + """Downloads and loads the CARLA capture data set.""" + download_data( + "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", + "./data", + ) + return SemanticSegmentationData.from_folders( + train_folder="data/CameraRGB", + train_target_folder="data/CameraSeg", + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + num_classes=num_classes, + **preprocess_kwargs, + ) + + +def semantic_segmentation(): + """Segment objects in images.""" + cli = FlashCLI( + SemanticSegmentation, + SemanticSegmentationData, + default_datamodule_builder=from_carla, + default_arguments={ + "trainer.max_epochs": 3, + }, + ) + + cli.trainer.save_checkpoint("semantic_segmentation_model.pt") + + +if __name__ == "__main__": + semantic_segmentation() diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index d933690a95..6b39ee1450 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -38,10 +38,11 @@ from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, _MATPLOTLIB_AVAILABLE, - _PIL_AVAILABLE, - _requires_extras, _TORCHVISION_AVAILABLE, + Image, lazy_import, + requires, + requires_extras, ) from flash.image.data import ImageDeserializer from flash.image.segmentation.serialization import SegmentationLabels @@ -62,20 +63,13 @@ if _TORCHVISION_AVAILABLE: import torchvision - from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS + import torchvision.transforms.functional as FT + from torchvision.datasets.folder import default_loader, has_file_allowed_extension, IMG_EXTENSIONS else: IMG_EXTENSIONS = None -if _PIL_AVAILABLE: - from PIL import Image -else: - - class Image: - Image = None - class SemanticSegmentationNumpyDataSource(NumpyDataSource): - def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: img = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float() sample[DefaultDataKeys.INPUT] = img @@ -84,7 +78,6 @@ def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> class SemanticSegmentationTensorDataSource(TensorDataSource): - def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: img = sample[DefaultDataKeys.INPUT].float() sample[DefaultDataKeys.INPUT] = img @@ -93,13 +86,13 @@ def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> class SemanticSegmentationPathsDataSource(PathsDataSource): - - @_requires_extras("image") + @requires_extras("image") def __init__(self): super().__init__(IMG_EXTENSIONS) - def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]], - dataset: BaseAutoDataset) -> Sequence[Mapping[str, Any]]: + def load_data( + self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]], dataset: BaseAutoDataset + ) -> Sequence[Mapping[str, Any]]: input_data, target_data = data if self.isdir(input_data) and self.isdir(target_data): @@ -130,8 +123,8 @@ def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]], data = filter( lambda sample: ( - has_file_allowed_extension(sample[0], self.extensions) and - has_file_allowed_extension(sample[1], self.extensions) + has_file_allowed_extension(sample[0], self.extensions) + and has_file_allowed_extension(sample[1], self.extensions) ), zip(input_data, target_data), ) @@ -149,7 +142,7 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten img_labels_path = sample[DefaultDataKeys.TARGET] # load images directly to torch tensors - img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW + img: torch.Tensor = FT.to_tensor(default_loader(img_path)) # CxHxW img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW img_labels = img_labels[0] # HxW @@ -164,7 +157,7 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten @staticmethod def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: img_path = sample[DefaultDataKeys.INPUT] - img = torchvision.io.read_image(img_path).float() + img = FT.to_tensor(default_loader(img_path)).float() sample[DefaultDataKeys.INPUT] = img sample[DefaultDataKeys.METADATA] = { @@ -175,8 +168,7 @@ def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: class SemanticSegmentationFiftyOneDataSource(FiftyOneDataSource): - - @_requires_extras("image") + @requires_extras("image") def __init__(self, label_field: str = "ground_truth"): super().__init__(label_field=label_field) self._fo_dataset_name = None @@ -197,7 +189,7 @@ def load_sample(self, sample: Mapping[str, str]) -> Mapping[str, Union[torch.Ten img_path = sample[DefaultDataKeys.INPUT] fo_sample = _fo_dataset[img_path] - img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW + img: torch.Tensor = FT.to_tensor(default_loader(img_path)) # CxHxW img_labels: torch.Tensor = torch.from_numpy(fo_sample[self.label_field].mask) # HxW sample[DefaultDataKeys.INPUT] = img.float() @@ -211,7 +203,7 @@ def load_sample(self, sample: Mapping[str, str]) -> Mapping[str, Union[torch.Ten @staticmethod def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: img_path = sample[DefaultDataKeys.INPUT] - img = torchvision.io.read_image(img_path).float() + img = FT.to_tensor(default_loader(img_path)).float() sample[DefaultDataKeys.INPUT] = img sample[DefaultDataKeys.METADATA] = { @@ -222,7 +214,6 @@ def predict_load_sample(sample: Mapping[str, Any]) -> Mapping[str, Any]: class SemanticSegmentationDeserializer(ImageDeserializer): - def deserialize(self, data: str) -> torch.Tensor: result = super().deserialize(data) result[DefaultDataKeys.INPUT] = self.to_tensor(result[DefaultDataKeys.INPUT]) @@ -231,16 +222,15 @@ def deserialize(self, data: str) -> torch.Tensor: class SemanticSegmentationPreprocess(Preprocess): - - @_requires_extras("image") + @requires_extras("image") def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - image_size: Tuple[int, int] = (196, 196), - deserializer: Optional['Deserializer'] = None, + image_size: Tuple[int, int] = (128, 128), + deserializer: Optional["Deserializer"] = None, num_classes: int = None, labels_map: Dict[int, Tuple[int, int, int]] = None, **data_source_kwargs: Any, @@ -283,9 +273,10 @@ def __init__( def get_state_dict(self) -> Dict[str, Any]: return { - **self.transforms, "image_size": self.image_size, + **self.transforms, + "image_size": self.image_size, "num_classes": self.num_classes, - "labels_map": self.labels_map + "labels_map": self.labels_map, } @classmethod @@ -307,7 +298,7 @@ class SemanticSegmentationData(DataModule): @staticmethod def configure_data_fetcher( labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None - ) -> 'SegmentationMatplotlibVisualization': + ) -> "SegmentationMatplotlibVisualization": return SegmentationMatplotlibVisualization(labels_map=labels_map) def set_block_viz_window(self, value: bool) -> None: @@ -332,22 +323,23 @@ def from_data_source( batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs: Any, - ) -> 'DataModule': + ) -> "DataModule": - if 'num_classes' not in preprocess_kwargs: + if "num_classes" not in preprocess_kwargs: raise MisconfigurationException("`num_classes` should be provided during instantiation.") num_classes = preprocess_kwargs["num_classes"] - labels_map = getattr(preprocess_kwargs, "labels_map", - None) or SegmentationLabels.create_random_labels_map(num_classes) + labels_map = getattr(preprocess_kwargs, "labels_map", None) or SegmentationLabels.create_random_labels_map( + num_classes + ) data_fetcher = data_fetcher or cls.configure_data_fetcher(labels_map) if flash._IS_TESTING: data_fetcher.block_viz_window = True - dm = super(SemanticSegmentationData, cls).from_data_source( + dm = super().from_data_source( data_source=data_source, train_data=train_data, val_data=val_data, @@ -362,7 +354,7 @@ def from_data_source( val_split=val_split, batch_size=batch_size, num_workers=num_workers, - **preprocess_kwargs + **preprocess_kwargs, ) if dm.train_dataset is not None: @@ -391,7 +383,7 @@ def from_folders( num_classes: Optional[int] = None, labels_map: Dict[int, Tuple[int, int, int]] = None, **preprocess_kwargs, - ) -> 'DataModule': + ) -> "DataModule": """Creates a :class:`~flash.image.segmentation.data.SemanticSegmentationData` object from the given data folders and corresponding target folders. @@ -459,8 +451,7 @@ def from_folders( class SegmentationMatplotlibVisualization(BaseVisualization): - """Process and show the image batch and its associated label using matplotlib. - """ + """Process and show the image batch and its associated label using matplotlib.""" def __init__(self, labels_map: Dict[int, Tuple[int, int, int]]): super().__init__() @@ -470,7 +461,7 @@ def __init__(self, labels_map: Dict[int, Tuple[int, int, int]]): self.labels_map: Dict[int, Tuple[int, int, int]] = labels_map @staticmethod - @_requires_extras("image") + @requires_extras("image") def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: out: np.ndarray if isinstance(img, Image.Image): @@ -481,7 +472,7 @@ def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: raise TypeError(f"Unknown image type. Got: {type(img)}.") return out - @_requires_extras("image") + @requires("matplotlib") def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str): # define the image grid cols: int = min(num_samples, self.max_cols) @@ -509,7 +500,7 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) img_vis = np.hstack((image_vis, label_vis)) # send to visualiser ax.imshow(img_vis) - ax.axis('off') + ax.axis("off") plt.show(block=self.block_viz_window) def show_load_sample(self, samples: List[Any], running_stage: RunningStage): diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index 97fd125dfd..4886dade8f 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -11,103 +11,66 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -import warnings from functools import partial +from typing import Union -import torch.nn as nn -from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch import nn from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE - -if _TORCHVISION_AVAILABLE: - from torchvision.models import MobileNetV3, ResNet - from torchvision.models._utils import IntermediateLayerGetter - from torchvision.models.segmentation.deeplabv3 import DeepLabHead, DeepLabV3 - from torchvision.models.segmentation.fcn import FCN, FCNHead - from torchvision.models.segmentation.lraspp import LRASPP - -if _BOLTS_AVAILABLE: - if os.getenv("WARN_MISSING_PACKAGE") == "0": - with warnings.catch_warnings(record=True) as w: - from pl_bolts.models.vision import UNet - else: - from pl_bolts.models.vision import UNet +from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE +from flash.core.utilities.providers import _SEGMENTATION_MODELS + +if _SEGMENTATION_MODELS_AVAILABLE: + import segmentation_models_pytorch as smp + + SMP_MODEL_CLASS = [ + smp.Unet, + smp.UnetPlusPlus, + smp.MAnet, + smp.Linknet, + smp.FPN, + smp.PSPNet, + smp.DeepLabV3, + smp.DeepLabV3Plus, + smp.PAN, + ] + SMP_MODELS = {a.__name__.lower(): a for a in SMP_MODEL_CLASS} SEMANTIC_SEGMENTATION_HEADS = FlashRegistry("backbones") -if _TORCHVISION_AVAILABLE: - - def _get_backbone_meta(backbone): - """Adapted from torchvision.models.segmentation.segmentation._segm_model: - https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/segmentation.py#L25 - """ - if isinstance(backbone, ResNet): - out_layer = 'layer4' - out_inplanes = 2048 - aux_layer = 'layer3' - aux_inplanes = 1024 - elif isinstance(backbone, MobileNetV3): - backbone = backbone.features - # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. - # The first and last blocks are always included because they are the C0 (conv1) and Cn. - stage_indices = [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] - stage_indices = [0] + stage_indices + [len(backbone) - 1] - out_pos = stage_indices[-1] # use C5 which has output_stride = 16 - out_layer = str(out_pos) - out_inplanes = backbone[out_pos].out_channels - aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8 - aux_layer = str(aux_pos) - aux_inplanes = backbone[aux_pos].out_channels - else: - raise MisconfigurationException( - f"{type(backbone)} backbone is not currently supported for semantic segmentation." - ) - return backbone, out_layer, out_inplanes, aux_layer, aux_inplanes - - def _load_fcn_deeplabv3(model_name, backbone, num_classes): - backbone, out_layer, out_inplanes, aux_layer, aux_inplanes = _get_backbone_meta(backbone) - - return_layers = {out_layer: 'out'} - backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) - - model_map = { - "deeplabv3": (DeepLabHead, DeepLabV3), - "fcn": (FCNHead, FCN), - } - classifier = model_map[model_name][0](out_inplanes, num_classes) - base_model = model_map[model_name][1] - - return base_model(backbone, classifier, None) +if _SEGMENTATION_MODELS_AVAILABLE: + + def _load_smp_head( + head: str, + backbone: str, + pretrained: Union[bool, str] = True, + num_classes: int = 1, + in_channels: int = 3, + **kwargs, + ) -> nn.Module: + + if head not in SMP_MODELS: + raise NotImplementedError(f"{head} is not implemented! Supported heads -> {SMP_MODELS.keys()}") + + encoder_weights = None + if isinstance(pretrained, str): + encoder_weights = pretrained + elif pretrained: + encoder_weights = "imagenet" + + return smp.create_model( + arch=head, + encoder_name=backbone, + encoder_weights=encoder_weights, + classes=num_classes, + in_channels=in_channels, + **kwargs, + ) - for model_name in ["fcn", "deeplabv3"]: + for model_name in SMP_MODELS: SEMANTIC_SEGMENTATION_HEADS( - fn=partial(_load_fcn_deeplabv3, model_name), + partial(_load_smp_head, head=model_name), name=model_name, namespace="image/segmentation", - package="torchvision", + providers=_SEGMENTATION_MODELS, ) - - def _load_lraspp(backbone, num_classes): - backbone, high_pos, high_channels, low_pos, low_channels = _get_backbone_meta(backbone) - backbone = IntermediateLayerGetter(backbone, return_layers={low_pos: 'low', high_pos: 'high'}) - return LRASPP(backbone, low_channels, high_channels, num_classes) - - SEMANTIC_SEGMENTATION_HEADS( - fn=_load_lraspp, - name="lraspp", - namespace="image/segmentation", - package="torchvision", - ) - -if _BOLTS_AVAILABLE: - - def _load_bolts_unet(_, num_classes: int, **kwargs) -> nn.Module: - rank_zero_warn("The UNet model does not require a backbone, so the backbone will be ignored.", UserWarning) - return UNet(num_classes, **kwargs) - - SEMANTIC_SEGMENTATION_HEADS( - fn=_load_bolts_unet, name="unet", namespace="image/segmentation", package="bolts", type="unet" - ) diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index 1951421315..771014bbb5 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -23,6 +23,7 @@ from flash.core.data.process import Postprocess, Serializer from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _KORNIA_AVAILABLE +from flash.core.utilities.isinstance import _isinstance from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS from flash.image.segmentation.serialization import SegmentationLabels @@ -32,9 +33,8 @@ class SemanticSegmentationPostprocess(Postprocess): - def per_sample_transform(self, sample: Any) -> Any: - resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA]["size"][-2:], interpolation='bilinear') + resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA]["size"][-2:], interpolation="bilinear") sample[DefaultDataKeys.PREDS] = resize(torch.stack(sample[DefaultDataKeys.PREDS])) sample[DefaultDataKeys.INPUT] = resize(torch.stack(sample[DefaultDataKeys.INPUT])) return super().per_sample_transform(sample) @@ -75,9 +75,9 @@ def __init__( num_classes: int, backbone: Union[str, nn.Module] = "resnet50", backbone_kwargs: Optional[Dict] = None, - head: str = "fcn", + head: str = "fpn", head_kwargs: Optional[Dict] = None, - pretrained: bool = True, + pretrained: Union[bool, str] = True, loss_fn: Optional[Callable] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, @@ -103,7 +103,7 @@ def __init__( metrics=metrics, learning_rate=learning_rate, serializer=serializer or SegmentationLabels(), - postprocess=postprocess or self.postprocess_cls() + postprocess=postprocess or self.postprocess_cls(), ) self.save_hyperparameters() @@ -117,9 +117,12 @@ def __init__( if isinstance(backbone, nn.Module): self.backbone = backbone else: - self.backbone = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs) + self.backbone = self.backbones.get(backbone)(**backbone_kwargs) - self.head = self.heads.get(head)(self.backbone, num_classes, **head_kwargs) + self.head: nn.Module = self.heads.get(head)( + backbone=self.backbone, num_classes=num_classes, pretrained=pretrained, **head_kwargs + ) + self.backbone = self.head.encoder def training_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) @@ -134,7 +137,7 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch_input = (batch[DefaultDataKeys.INPUT]) + batch_input = batch[DefaultDataKeys.INPUT] batch[DefaultDataKeys.PREDS] = super().predict_step(batch_input, batch_idx, dataloader_idx=dataloader_idx) return batch @@ -144,18 +147,22 @@ def forward(self, x) -> torch.Tensor: # some frameworks like torchvision return a dict. # In particular, torchvision segmentation models return the output logits # in the key `out`. - if torch.jit.isinstance(res, Dict[str, torch.Tensor]): - out = res['out'] - elif torch.is_tensor(res): - out = res - else: - raise NotImplementedError(f"Unsupported output type: {type(res)}") + if _isinstance(res, Dict[str, torch.Tensor]): + res = res["out"] + + return res + + @classmethod + def available_pretrained_weights(cls, backbone: str): + result = cls.backbones.get(backbone, with_metadata=True) + pretrained_weights = None + + if "weights_paths" in result["metadata"]: + pretrained_weights = list(result["metadata"]["weights_paths"]) - return out + return pretrained_weights @staticmethod def _ci_benchmark_fn(history: List[Dict[str, Any]]): - """ - This function is used only for debugging usage with CI - """ + """This function is used only for debugging usage with CI.""" assert history[-1]["val_iou"] > 0.2 diff --git a/flash/image/segmentation/serialization.py b/flash/image/segmentation/serialization.py index 16d51beb63..8bc893fce3 100644 --- a/flash/image/segmentation/serialization.py +++ b/flash/image/segmentation/serialization.py @@ -19,7 +19,14 @@ import flash from flash.core.data.data_source import DefaultDataKeys, ImageLabelsMap from flash.core.data.process import Serializer -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE, lazy_import +from flash.core.utilities.imports import ( + _FIFTYONE_AVAILABLE, + _KORNIA_AVAILABLE, + _MATPLOTLIB_AVAILABLE, + lazy_import, + requires, + requires_extras, +) Segmentation = None if _FIFTYONE_AVAILABLE: @@ -41,15 +48,15 @@ class SegmentationLabels(Serializer): - """A :class:`.Serializer` which converts the model outputs to the label of - the argmax classification per pixel in the image for semantic segmentation - tasks. + """A :class:`.Serializer` which converts the model outputs to the label of the argmax classification per pixel + in the image for semantic segmentation tasks. Args: labels_map: A dictionary that map the labels ids to pixel intensities. visualize: Wether to visualize the image labels. """ + @requires_extras("image") def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False): super().__init__() self.labels_map = labels_map @@ -57,14 +64,13 @@ def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, @staticmethod def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, int, int]]) -> torch.Tensor: - """Function that given an image with labels ids and their pixels intrensity mapping, - creates a RGB representation for visualisation purposes. - """ + """Function that given an image with labels ids and their pixels intrensity mapping, creates a RGB + representation for visualisation purposes.""" assert len(img_labels.shape) == 2, img_labels.shape H, W = img_labels.shape out = torch.empty(3, H, W, dtype=torch.uint8) for label_id, label_val in labels_map.items(): - mask = (img_labels == label_id) + mask = img_labels == label_id for i in range(3): out[i].masked_fill_(mask, label_val[i]) return out @@ -73,27 +79,30 @@ def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, i def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]: labels_map: Dict[int, Tuple[int, int, int]] = {} for i in range(num_classes): - labels_map[i] = torch.randint(0, 255, (3, )) + labels_map[i] = torch.randint(0, 255, (3,)) return labels_map + @requires("matplotlib") + def _visualize(self, labels): + if self.labels_map is None: + self.labels_map = self.get_state(ImageLabelsMap).labels_map + labels_vis = self.labels_to_image(labels, self.labels_map) + labels_vis = K.utils.tensor_to_image(labels_vis) + plt.imshow(labels_vis) + plt.show() + def serialize(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: preds = sample[DefaultDataKeys.PREDS] assert len(preds.shape) == 3, preds.shape labels = torch.argmax(preds, dim=-3) # HxW if self.visualize and not flash._IS_TESTING: - if self.labels_map is None: - self.labels_map = self.get_state(ImageLabelsMap).labels_map - labels_vis = self.labels_to_image(labels, self.labels_map) - labels_vis = K.utils.tensor_to_image(labels_vis) - plt.imshow(labels_vis) - plt.show() + self._visualize(labels) return labels.tolist() class FiftyOneSegmentationLabels(SegmentationLabels): - """A :class:`.Serializer` which converts the model outputs to FiftyOne - segmentation format. + """A :class:`.Serializer` which converts the model outputs to FiftyOne segmentation format. Args: labels_map: A dictionary that map the labels ids to pixel intensities. @@ -103,15 +112,13 @@ class FiftyOneSegmentationLabels(SegmentationLabels): FiftyOne labels (False). """ + @requires("fiftyone") def __init__( self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False, return_filepath: bool = False, ): - if not _FIFTYONE_AVAILABLE: - raise ModuleNotFoundError("Please, run `pip install fiftyone`.") - super().__init__(labels_map=labels_map, visualize=visualize) self.return_filepath = return_filepath diff --git a/flash/image/segmentation/transforms.py b/flash/image/segmentation/transforms.py index 92ef2b45bd..53bd0a6314 100644 --- a/flash/image/segmentation/transforms.py +++ b/flash/image/segmentation/transforms.py @@ -29,7 +29,7 @@ def prepare_target(tensor: torch.Tensor) -> torch.Tensor: - """ Convert the target mask to long and remove the channel dimension. """ + """Convert the target mask to long and remove the channel dimension.""" return tensor.long().squeeze(1) @@ -40,7 +40,7 @@ def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: "post_tensor_transform": nn.Sequential( ApplyToKeys( [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], - KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation='nearest')), + KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation="nearest")), ), ), "collate": Compose([kornia_collate, ApplyToKeys(DefaultDataKeys.TARGET, prepare_target)]), @@ -48,14 +48,16 @@ def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: - """During training, we apply the default transforms with additional ``RandomHorizontalFlip`` and ``ColorJitter``.""" + """During training, we apply the default transforms with additional ``RandomHorizontalFlip`` and + ``ColorJitter``.""" return merge_transforms( - default_transforms(image_size), { + default_transforms(image_size), + { "post_tensor_transform": nn.Sequential( ApplyToKeys( [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], KorniaParallelTransforms(K.augmentation.RandomHorizontalFlip(p=0.5)), ), ), - } + }, ) diff --git a/flash/image/style_transfer/backbones.py b/flash/image/style_transfer/backbones.py index b9437e64ff..07c05f1ca1 100644 --- a/flash/image/style_transfer/backbones.py +++ b/flash/image/style_transfer/backbones.py @@ -15,6 +15,7 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _PYSTICHE_AVAILABLE +from flash.core.utilities.providers import _PYSTICHE STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones") @@ -26,8 +27,6 @@ MLE_FN_PATTERN = re.compile(r"^(?P\w+?)_multi_layer_encoder$") - STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones") - for mle_fn in dir(enc): match = MLE_FN_PATTERN.match(mle_fn) if not match: @@ -37,5 +36,5 @@ fn=lambda: (getattr(enc, mle_fn)(), None), name=match.group("name"), namespace="image/style_transfer", - package="pystiche", + providers=_PYSTICHE, ) diff --git a/flash/image/style_transfer/cli.py b/flash/image/style_transfer/cli.py new file mode 100644 index 0000000000..0fec347021 --- /dev/null +++ b/flash/image/style_transfer/cli.py @@ -0,0 +1,57 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Optional + +import flash +from flash.core.data.utils import download_data +from flash.core.utilities.flash_cli import FlashCLI +from flash.image import StyleTransfer, StyleTransferData + +__all__ = ["style_transfer"] + + +def from_coco_128( + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> StyleTransferData: + """Downloads and loads the COCO 128 data set.""" + download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") + return StyleTransferData.from_folders( + train_folder="data/coco128/images/train2017/", + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def style_transfer(): + """Image style transfer.""" + cli = FlashCLI( + StyleTransfer, + StyleTransferData, + default_datamodule_builder=from_coco_128, + default_arguments={ + "trainer.max_epochs": 3, + "model.style_image": os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"), + }, + finetune=False, + ) + + cli.trainer.save_checkpoint("style_transfer_model.pt") + + +if __name__ == "__main__": + style_transfer() diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index 75ab6f9e7a..f9f63c5905 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -17,6 +17,7 @@ from torch import nn +from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources from flash.core.data.process import Preprocess from flash.core.data.transforms import ApplyToKeys @@ -31,9 +32,9 @@ __all__ = ["StyleTransferPreprocess", "StyleTransferData"] -def _apply_to_input(default_transforms_fn, keys: Union[Sequence[DefaultDataKeys], - DefaultDataKeys]) -> Callable[..., Dict[str, ApplyToKeys]]: - +def _apply_to_input( + default_transforms_fn, keys: Union[Sequence[DefaultDataKeys], DefaultDataKeys] +) -> Callable[..., Dict[str, ApplyToKeys]]: @functools.wraps(default_transforms_fn) def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]: default_transforms = default_transforms_fn(*args, **kwargs) @@ -46,7 +47,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]: class StyleTransferPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Union[Dict[str, Callable]]] = None, @@ -118,12 +118,12 @@ def from_folders( predict_transform: Optional[Union[str, Dict]] = None, preprocess: Optional[Preprocess] = None, **kwargs: Any, - ) -> "StyleTransferData": + ) -> "DataModule": - if any(param in kwargs for param in ("val_folder", "val_transform")): + if any(param in kwargs and kwargs[param] is not None for param in ("val_folder", "val_transform")): raise_not_supported("validation") - if any(param in kwargs for param in ("test_folder", "test_transform")): + if any(param in kwargs and kwargs[param] is not None for param in ("test_folder", "test_transform")): raise_not_supported("test") preprocess = preprocess or cls.preprocess_cls( diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index 1573a10612..86a6b723e5 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -26,7 +26,7 @@ if _IMAGE_AVAILABLE: import pystiche.demo - from pystiche import enc, loss, ops + from pystiche import enc, loss from pystiche.image import read_image else: @@ -34,12 +34,9 @@ class enc: Encoder = None MultiLayerEncoder = None - class ops: - EncodingComparisonOperator = None - FeatureReconstructionOperator = None - MultiLayerEncodingOperator = None - class loss: + class GramLoss: + pass class PerceptualLoss: pass @@ -80,7 +77,7 @@ def __init__( backbone: str = "vgg16", content_layer: str = "relu2_2", content_weight: float = 1e5, - style_layers: Union[Sequence[str], str] = ("relu1_2", "relu2_2", "relu3_3", "relu4_3"), + style_layers: Union[Sequence[str], str] = ["relu1_2", "relu2_2", "relu3_3", "relu4_3"], style_weight: float = 1e10, optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, @@ -100,7 +97,7 @@ def __init__( model = pystiche.demo.transformer() if not isinstance(style_layers, (List, Tuple)): - style_layers = (style_layers, ) + style_layers = (style_layers,) perceptual_loss = self._get_perceptual_loss( backbone=backbone, @@ -129,12 +126,11 @@ def default_style_image() -> torch.Tensor: return pystiche.demo.images()["paint"].read(size=256) @staticmethod - def _modified_gram_loss(encoder: enc.Encoder, *, score_weight: float) -> ops.EncodingComparisonOperator: + def _modified_gram_loss(encoder: enc.Encoder, *, score_weight: float) -> loss.GramLoss: # The official PyTorch examples as well as the reference implementation of the original author contain an # oversight: they normalize the representation twice by the number of channels. To be compatible with them, we # do the same here. - class GramOperator(ops.GramOperator): - + class GramOperator(loss.GramLoss): def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor: repr = super().enc_to_repr(enc) num_channels = repr.size()[1] @@ -152,10 +148,8 @@ def _get_perceptual_loss( style_weight: float, ) -> loss.PerceptualLoss: mle, _ = cast(enc.MultiLayerEncoder, self.backbones.get(backbone)()) - content_loss = ops.FeatureReconstructionOperator( - mle.extract_encoder(content_layer), score_weight=content_weight - ) - style_loss = ops.MultiLayerEncodingOperator( + content_loss = loss.FeatureReconstructionLoss(mle.extract_encoder(content_layer), score_weight=content_weight) + style_loss = loss.MultiLayerEncodingLoss( mle, style_layers, lambda encoder, layer_weight: self._modified_gram_loss(encoder, score_weight=layer_weight), diff --git a/flash/pointcloud/__init__.py b/flash/pointcloud/__init__.py new file mode 100644 index 0000000000..766f2f2e89 --- /dev/null +++ b/flash/pointcloud/__init__.py @@ -0,0 +1,2 @@ +from flash.pointcloud.detection import PointCloudObjectDetector, PointCloudObjectDetectorData # noqa: F401 +from flash.pointcloud.segmentation import PointCloudSegmentation, PointCloudSegmentationData # noqa: F401 diff --git a/flash/pointcloud/detection/__init__.py b/flash/pointcloud/detection/__init__.py new file mode 100644 index 0000000000..cfe4c690f0 --- /dev/null +++ b/flash/pointcloud/detection/__init__.py @@ -0,0 +1,3 @@ +from flash.pointcloud.detection.data import PointCloudObjectDetectorData # noqa: F401 +from flash.pointcloud.detection.model import PointCloudObjectDetector # noqa: F401 +from flash.pointcloud.detection.open3d_ml.app import launch_app # noqa: F401 diff --git a/flash/pointcloud/detection/backbones.py b/flash/pointcloud/detection/backbones.py new file mode 100644 index 0000000000..88268dd036 --- /dev/null +++ b/flash/pointcloud/detection/backbones.py @@ -0,0 +1,19 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.core.registry import FlashRegistry +from flash.pointcloud.detection.open3d_ml.backbones import register_open_3d_ml + +POINTCLOUD_OBJECT_DETECTION_BACKBONES = FlashRegistry("backbones") + +register_open_3d_ml(POINTCLOUD_OBJECT_DETECTION_BACKBONES) diff --git a/flash/pointcloud/detection/cli.py b/flash/pointcloud/detection/cli.py new file mode 100644 index 0000000000..01a4c329ce --- /dev/null +++ b/flash/pointcloud/detection/cli.py @@ -0,0 +1,55 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from flash.core.data.utils import download_data +from flash.core.utilities.flash_cli import FlashCLI +from flash.pointcloud import PointCloudObjectDetector, PointCloudObjectDetectorData + +__all__ = ["pointcloud_detection"] + + +def from_kitti( + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> PointCloudObjectDetectorData: + """Downloads and loads the KITTI data set.""" + download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/") + return PointCloudObjectDetectorData.from_folders( + train_folder="data/KITTI_Tiny/Kitti/train", + val_folder="data/KITTI_Tiny/Kitti/val", + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def pointcloud_detection(): + """Detect objects in point clouds.""" + cli = FlashCLI( + PointCloudObjectDetector, + PointCloudObjectDetectorData, + default_datamodule_builder=from_kitti, + default_arguments={ + "trainer.max_epochs": 3, + }, + finetune=False, + ) + + cli.trainer.save_checkpoint("pointcloud_detection_model.pt") + + +if __name__ == "__main__": + pointcloud_detection() diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py new file mode 100644 index 0000000000..40349b8653 --- /dev/null +++ b/flash/pointcloud/detection/data.py @@ -0,0 +1,176 @@ +from typing import Any, Callable, Dict, Optional, Type + +from torch.utils.data import Sampler + +from flash.core.data.base_viz import BaseDataFetcher +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import BaseDataFormat, DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.process import Deserializer, Preprocess +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE, requires_extras + +if _POINTCLOUD_AVAILABLE: + from flash.pointcloud.detection.open3d_ml.data_sources import ( + PointCloudObjectDetectionDataFormat, + PointCloudObjectDetectorFoldersDataSource, + ) +else: + PointCloudObjectDetectorFoldersDataSource = object + + class PointCloudObjectDetectionDataFormat: + KITTI = None + + +class PointCloudObjectDetectorDatasetDataSource(DataSource): + def __init__(self, **kwargs): + super().__init__() + + def load_data( + self, + data: Any, + dataset: Optional[Any] = None, + ) -> Any: + + dataset.dataset = data + + return range(len(data)) + + def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any: + sample = dataset.dataset[index] + + return { + DefaultDataKeys.INPUT: sample["data"], + DefaultDataKeys.METADATA: sample["attr"], + } + + +class PointCloudObjectDetectorPreprocess(Preprocess): + @requires_extras("pointcloud") + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + deserializer: Optional[Deserializer] = None, + **data_source_kwargs, + ): + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.DATASETS: PointCloudObjectDetectorDatasetDataSource(**data_source_kwargs), + DefaultDataSources.FOLDERS: PointCloudObjectDetectorFoldersDataSource(**data_source_kwargs), + }, + deserializer=deserializer, + default_data_source=DefaultDataSources.FOLDERS, + ) + + def get_state_dict(self): + return {} + + def state_dict(self): + return {} + + @classmethod + def load_state_dict(cls, state_dict, strict: bool = False): + pass + + +class PointCloudObjectDetectorData(DataModule): + + preprocess_cls = PointCloudObjectDetectorPreprocess + + @classmethod + def from_folders( + cls, + train_folder: Optional[str] = None, + val_folder: Optional[str] = None, + test_folder: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + sampler: Optional[Type[Sampler]] = None, + scans_folder_name: Optional[str] = "scans", + labels_folder_name: Optional[str] = "labels", + calibrations_folder_name: Optional[str] = "calibs", + data_format: Optional[BaseDataFormat] = PointCloudObjectDetectionDataFormat.KITTI, + **preprocess_kwargs: Any, + ) -> "DataModule": + """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the + :class:`~flash.core.data.data_source.DataSource` of name + :attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS` + from the passed or constructed :class:`~flash.core.data.process.Preprocess`. + + Args: + train_folder: The folder containing the train data. + val_folder: The folder containing the validation data. + test_folder: The folder containing the test data. + predict_folder: The folder containing the predict data. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + sampler: The ``sampler`` to use for the ``train_dataloader``. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + scans_folder_name: The name of the pointcloud scan folder + labels_folder_name: The name of the pointcloud scan labels folder + calibrations_folder_name: The name of the pointcloud scan calibration folder + data_format: Format in which the data are stored. + + Returns: + The constructed data module. + + Examples:: + + data_module = DataModule.from_folders( + train_folder="train_folder", + train_transform={ + "to_tensor_transform": torch.as_tensor, + }, + ) + """ + return cls.from_data_source( + DefaultDataSources.FOLDERS, + train_folder, + val_folder, + test_folder, + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + sampler=sampler, + scans_folder_name=scans_folder_name, + labels_folder_name=labels_folder_name, + calibrations_folder_name=calibrations_folder_name, + data_format=data_format, + **preprocess_kwargs, + ) diff --git a/flash/pointcloud/detection/datasets.py b/flash/pointcloud/detection/datasets.py new file mode 100644 index 0000000000..335f699757 --- /dev/null +++ b/flash/pointcloud/detection/datasets.py @@ -0,0 +1,41 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.pointcloud.segmentation.datasets import executor + +if _POINTCLOUD_AVAILABLE: + from open3d.ml.datasets import KITTI + +_OBJECT_DETECTION_DATASET = FlashRegistry("dataset") + + +@_OBJECT_DETECTION_DATASET +def kitti(dataset_path, download, **kwargs): + name = "KITTI" + download_path = os.path.join(dataset_path, name, "Kitti") + if not os.path.exists(download_path): + executor( + "https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/scripts/download_datasets/download_kitti.sh", # noqa E501 + None, + dataset_path, + name, + ) + return KITTI(download_path, **kwargs) + + +def KITTIDataset(dataset_path, download: bool = True, **kwargs): + return _OBJECT_DETECTION_DATASET.get("kitti")(dataset_path, download, **kwargs) diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py new file mode 100644 index 0000000000..155126d785 --- /dev/null +++ b/flash/pointcloud/detection/model.py @@ -0,0 +1,183 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union + +import torch +import torchmetrics +from torch import nn +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader, Sampler + +from flash.core.data.auto_dataset import BaseAutoDataset +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.process import Serializer +from flash.core.data.states import CollateFn +from flash.core.model import Task +from flash.core.registry import FlashRegistry +from flash.core.utilities.apply_func import get_callable_dict +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.pointcloud.detection.backbones import POINTCLOUD_OBJECT_DETECTION_BACKBONES + +__FILE_EXAMPLE__ = "pointcloud_detection" + + +class PointCloudObjectDetectorSerializer(Serializer): + pass + + +class PointCloudObjectDetector(Task): + """The ``PointCloudObjectDetector`` is a :class:`~flash.core.classification.ClassificationTask` that classifies + pointcloud data. + + Args: + num_features: The number of features (elements) in the input data. + num_classes: The number of classes (outputs) for this :class:`~flash.core.model.Task`. + backbone: The backbone name (or a tuple of ``nn.Module``, output size) to use. + backbone_kwargs: Any additional kwargs to pass to the backbone constructor. + loss_fn: The loss function to use. If ``None``, a default will be selected by the + :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. + optimizer: The optimizer or optimizer class to use. + optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). + scheduler: The scheduler or scheduler class to use. + scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected + by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. + learning_rate: The learning rate for the optimizer. + multi_label: If ``True``, this will be treated as a multi-label classification problem. + serializer: The :class:`~flash.core.data.process.Serializer` to use for prediction outputs. + lambda_loss_cls: The value to scale the loss classification. + lambda_loss_bbox: The value to scale the bounding boxes loss. + lambda_loss_dir: The value to scale the bounding boxes direction loss. + """ + + backbones: FlashRegistry = POINTCLOUD_OBJECT_DETECTION_BACKBONES + required_extras: str = "pointcloud" + + def __init__( + self, + num_classes: int, + backbone: Union[str, Tuple[nn.Module, int]] = "pointpillars_kitti", + backbone_kwargs: Optional[Dict] = None, + head: Optional[nn.Module] = None, + loss_fn: Optional[Callable] = None, + optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, + scheduler_kwargs: Optional[Dict[str, Any]] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + learning_rate: float = 1e-2, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudObjectDetectorSerializer(), + lambda_loss_cls: float = 1.0, + lambda_loss_bbox: float = 1.0, + lambda_loss_dir: float = 1.0, + ): + + super().__init__( + model=None, + loss_fn=loss_fn, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + scheduler=scheduler, + scheduler_kwargs=scheduler_kwargs, + metrics=metrics, + learning_rate=learning_rate, + serializer=serializer, + ) + + self.save_hyperparameters() + + if backbone_kwargs is None: + backbone_kwargs = {} + + if isinstance(backbone, tuple): + self.backbone, out_features = backbone + else: + self.model, out_features, collate_fn = self.backbones.get(backbone)(**backbone_kwargs) + self.backbone = self.model.backbone + self.neck = self.model.neck + self.set_state(CollateFn(collate_fn)) + self.set_state(CollateFn(collate_fn)) + self.set_state(CollateFn(collate_fn)) + self.loss_fn = get_callable_dict(self.model.loss) + + if __FILE_EXAMPLE__ not in sys.argv[0]: + self.model.bbox_head.conv_cls = self.head = nn.Conv2d( + out_features, num_classes, kernel_size=(1, 1), stride=(1, 1) + ) + + def compute_loss(self, losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + losses = losses["loss"] + return ( + self.hparams.lambda_loss_cls * losses["loss_cls"] + + self.hparams.lambda_loss_bbox * losses["loss_bbox"] + + self.hparams.lambda_loss_dir * losses["loss_dir"] + ) + + def compute_logs(self, logs: Dict[str, Any], losses: Dict[str, torch.Tensor]): + logs.update({"loss": self.compute_loss(losses)}) + return logs + + def training_step(self, batch: Any, batch_idx: int) -> Any: + return super().training_step((batch, batch), batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + super().validation_step((batch, batch), batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + super().validation_step((batch, batch), batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + results = self.model(batch) + boxes = self.model.inference_end(results, batch) + return { + DefaultDataKeys.INPUT: getattr(batch, "point", None), + DefaultDataKeys.PREDS: boxes, + DefaultDataKeys.METADATA: [a["name"] for a in batch.attr], + } + + def forward(self, x) -> torch.Tensor: + """First call the backbone, then the model head.""" + # hack to enable backbone to work properly. + self.model.device = self.device + return self.model(x) + + def _process_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + + if not _POINTCLOUD_AVAILABLE: + raise ModuleNotFoundError("Please, run `pip install flash[pointcloud]`.") + + dataset.preprocess_fn = self.model.preprocess + dataset.transform_fn = self.model.transform + + return DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) diff --git a/flash/pointcloud/detection/open3d_ml/app.py b/flash/pointcloud/detection/open3d_ml/app.py new file mode 100644 index 0000000000..bddcfe7e41 --- /dev/null +++ b/flash/pointcloud/detection/open3d_ml/app.py @@ -0,0 +1,169 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import torch +from torch.utils.data.dataset import Dataset + +import flash +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DefaultDataKeys +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE + +if _POINTCLOUD_AVAILABLE: + + from open3d._ml3d.vis.visualizer import LabelLUT, Visualizer + from open3d.visualization import gui + + class Visualizer(Visualizer): + def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768): + """Visualize a dataset. + + Example: + Minimal example for visualizing a dataset:: + import open3d.ml.torch as ml3d # or open3d.ml.tf as ml3d + + dataset = ml3d.datasets.SemanticKITTI(dataset_path='/path/to/SemanticKITTI/') + vis = ml3d.vis.Visualizer() + vis.visualize_dataset(dataset, 'all', indices=range(100)) + + Args: + dataset: The dataset to use for visualization. + split: The dataset split to be used, such as 'training' + indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4]. + width: The width of the visualization window. + height: The height of the visualization window. + """ + # Setup the labels + lut = LabelLUT() + for id, color in dataset.color_map.items(): + lut.add_label(id, id, color=color) + self.set_lut("label", lut) + + self._consolidate_bounding_boxes = True + self._init_dataset(dataset, split, indices) + + self._visualize("Open3D - " + dataset.name, width, height) + + def _visualize(self, title, width, height): + gui.Application.instance.initialize() + self._init_user_interface(title, width, height) + + # override just to set background color to back :) + bgcolor = gui.ColorEdit() + bgcolor.color_value = gui.Color(0, 0, 0) + self._on_bgcolor_changed(bgcolor.color_value) + + self._3d.scene.downsample_threshold = 400000 + + # Turn all the objects off except the first one + for name, node in self._name2treenode.items(): + node.checkbox.checked = False + self._3d.scene.show_geometry(name, False) + for name in [self._objects.data_names[0]]: + self._name2treenode[name].checkbox.checked = True + self._3d.scene.show_geometry(name, True) + + def on_done_ui(): + # Add bounding boxes here: bounding boxes belonging to the dataset + # will not be loaded until now. + self._update_bounding_boxes() + + self._update_datasource_combobox() + self._update_shaders_combobox() + + # Display "colors" by default if available, "points" if not + available_attrs = self._get_available_attrs() + self._set_shader(self.SOLID_NAME, force_update=True) + if "colors" in available_attrs: + self._datasource_combobox.selected_text = "colors" + elif "points" in available_attrs: + self._datasource_combobox.selected_text = "points" + + self._dont_update_geometry = True + self._on_datasource_changed( + self._datasource_combobox.selected_text, self._datasource_combobox.selected_index + ) + self._update_geometry_colors() + self._dont_update_geometry = False + # _datasource_combobox was empty, now isn't, re-layout. + self.window.set_needs_layout() + + self._update_geometry() + self.setup_camera() + + self._load_geometries(self._objects.data_names, on_done_ui) + gui.Application.instance.run() + + class VizDataset(Dataset): + + name = "VizDataset" + + def __init__(self, dataset): + self.dataset = dataset + self.label_to_names = getattr(dataset, "label_to_names", {}) + self.path_list = getattr(dataset, "path_list", []) + self.color_map = getattr(dataset, "color_map", {}) + + def get_data(self, index): + data = self.dataset[index]["data"] + data["bounding_boxes"] = data["bbox_objs"] + data["color"] = np.ones_like(data["point"]) + return data + + def get_attr(self, index): + return self.dataset[index]["attr"] + + def get_split(self, *_) -> "VizDataset": + return self + + def __len__(self) -> int: + return len(self.dataset) + + class App: + def __init__(self, datamodule: DataModule): + self.datamodule = datamodule + self._enabled = not flash._IS_TESTING + + def get_dataset(self, stage: str = "train"): + dataloader = getattr(self.datamodule, f"{stage}_dataloader")() + return VizDataset(dataloader.dataset) + + def show_train_dataset(self, indices=None): + if self._enabled: + dataset = self.get_dataset("train") + viz = Visualizer() + viz.visualize_dataset(dataset, "all", indices=indices) + + def show_predictions(self, predictions): + if self._enabled: + dataset = self.get_dataset("train") + + viz = Visualizer() + lut = LabelLUT() + for id, color in dataset.color_map.items(): + lut.add_label(id, id, color=color) + viz.set_lut("label", lut) + + for pred in predictions: + data = { + "points": torch.stack(pred[DefaultDataKeys.INPUT])[:, :3], + "name": pred[DefaultDataKeys.METADATA], + } + bounding_box = pred[DefaultDataKeys.PREDS] + + viz.visualize([data], bounding_boxes=bounding_box) + + +def launch_app(datamodule: DataModule) -> "App": + return App(datamodule) diff --git a/flash/pointcloud/detection/open3d_ml/backbones.py b/flash/pointcloud/detection/open3d_ml/backbones.py new file mode 100644 index 0000000000..759b6bdb43 --- /dev/null +++ b/flash/pointcloud/detection/open3d_ml/backbones.py @@ -0,0 +1,83 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from abc import ABC +from typing import Callable + +import torch +from pytorch_lightning.utilities.cloud_io import load as pl_load + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.providers import _OPEN3D_ML + +ROOT_URL = "https://storage.googleapis.com/open3d-releases/model-zoo/" + +if _POINTCLOUD_AVAILABLE: + import open3d + import open3d.ml as _ml3d + from open3d._ml3d.torch.dataloaders.concat_batcher import ConcatBatcher, ObjectDetectBatch + from open3d._ml3d.torch.models.point_pillars import PointPillars + from open3d.ml.torch.dataloaders import DefaultBatcher +else: + ObjectDetectBatch = ABC + PointPillars = ABC + + +class ObjectDetectBatchCollator(ObjectDetectBatch): + def __init__(self, batches): + self.num_batches = len(batches) + super().__init__(batches) + + def to(self, device): + super().to(device) + return self + + def __len__(self): + return self.num_batches + + +def register_open_3d_ml(register: FlashRegistry): + + if _POINTCLOUD_AVAILABLE: + + CONFIG_PATH = os.path.join(os.path.dirname(open3d.__file__), "_ml3d/configs") + + def get_collate_fn(model) -> Callable: + batcher_name = model.cfg.batcher + if batcher_name == "DefaultBatcher": + batcher = DefaultBatcher() + elif batcher_name == "ConcatBatcher": + batcher = ConcatBatcher(torch, model.__class__.__name__) + elif batcher_name == "ObjectDetectBatchCollator": + return ObjectDetectBatchCollator + return batcher.collate_fn + + @register(parameters=PointPillars.__init__, providers=_OPEN3D_ML) + def pointpillars_kitti(*args, **kwargs) -> PointPillars: + cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "pointpillars_kitti.yml")) + cfg.model.device = "cpu" + model = PointPillars(**cfg.model) + weight_url = os.path.join(ROOT_URL, "pointpillars_kitti_202012221652utc.pth") + model.load_state_dict( + pl_load(weight_url, map_location="cpu")["model_state_dict"], + ) + model.cfg.batcher = "ObjectDetectBatchCollator" + return model, 384, get_collate_fn(model) + + @register(parameters=PointPillars.__init__, providers=_OPEN3D_ML) + def pointpillars(*args, **kwargs) -> PointPillars: + model = PointPillars(*args, **kwargs) + model.cfg.batcher = "ObjectDetectBatch" + return model, get_collate_fn(model) diff --git a/flash/pointcloud/detection/open3d_ml/data_sources.py b/flash/pointcloud/detection/open3d_ml/data_sources.py new file mode 100644 index 0000000000..0c4872c3b3 --- /dev/null +++ b/flash/pointcloud/detection/open3d_ml/data_sources.py @@ -0,0 +1,241 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from os.path import basename, dirname, exists, isdir, isfile, join +from typing import Any, Dict, List, Optional, Union + +import yaml +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from flash.core.data.auto_dataset import BaseAutoDataset +from flash.core.data.data_source import BaseDataFormat, DataSource +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE + +if _POINTCLOUD_AVAILABLE: + from open3d._ml3d.datasets.kitti import DataProcessing, KITTI + + +class PointCloudObjectDetectionDataFormat(BaseDataFormat): + KITTI = "kitti" + + +class BasePointCloudObjectDetectorLoader: + + pass + + +class KITTIPointCloudObjectDetectorLoader(BasePointCloudObjectDetectorLoader): + def __init__( + self, + image_size: tuple = (375, 1242), + scans_folder_name: Optional[str] = "scans", + labels_folder_name: Optional[str] = "labels", + calibrations_folder_name: Optional[str] = "calibs", + **kwargs, + ): + + self.image_size = image_size + self.scans_folder_name = scans_folder_name + self.labels_folder_name = labels_folder_name + self.calibrations_folder_name = calibrations_folder_name + + def load_meta(self, root_dir, dataset: Optional[BaseAutoDataset]): + meta_file = join(root_dir, "meta.yaml") + if not exists(meta_file): + raise MisconfigurationException(f"The {root_dir} should contain a `meta.yaml` file about the classes.") + + with open(meta_file) as f: + self.meta = yaml.safe_load(f) + + if "label_to_names" not in self.meta: + raise MisconfigurationException( + f"The {root_dir} should contain a `meta.yaml` file about the classes with the field `label_to_names`." + ) + + dataset.num_classes = len(self.meta["label_to_names"]) + dataset.label_to_names = self.meta["label_to_names"] + dataset.color_map = self.meta["color_map"] + + def load_data(self, folder: str, dataset: Optional[BaseAutoDataset]): + sub_directories = os.listdir(folder) + if len(sub_directories) != 3: + raise MisconfigurationException( + f"Using KITTI Format, the {folder} should contains 3 directories " + "for ``calibrations``, ``labels`` and ``scans``." + ) + + assert self.scans_folder_name in sub_directories + assert self.labels_folder_name in sub_directories + assert self.calibrations_folder_name in sub_directories + + scans_dir = join(folder, self.scans_folder_name) + labels_dir = join(folder, self.labels_folder_name) + calibrations_dir = join(folder, self.calibrations_folder_name) + + scan_paths = [join(scans_dir, f) for f in os.listdir(scans_dir)] + label_paths = [join(labels_dir, f) for f in os.listdir(labels_dir)] + calibration_paths = [join(calibrations_dir, f) for f in os.listdir(calibrations_dir)] + + assert len(scan_paths) == len(label_paths) == len(calibration_paths) + + self.load_meta(dirname(folder), dataset) + + dataset.path_list = scan_paths + + return [ + {"scan_path": scan_path, "label_path": label_path, "calibration_path": calibration_path} + for scan_path, label_path, calibration_path, in zip(scan_paths, label_paths, calibration_paths) + ] + + def load_sample( + self, sample: Dict[str, str], dataset: Optional[BaseAutoDataset] = None, has_label: bool = True + ) -> Any: + pc = KITTI.read_lidar(sample["scan_path"]) + calib = KITTI.read_calib(sample["calibration_path"]) + label = None + if has_label: + label = KITTI.read_label(sample["label_path"], calib) + + reduced_pc = DataProcessing.remove_outside_points(pc, calib["world_cam"], calib["cam_img"], self.image_size) + + attr = { + "name": basename(sample["scan_path"]), + "path": sample["scan_path"], + "calibration_path": sample["calibration_path"], + "label_path": sample["label_path"] if has_label else None, + "split": "val", + } + + data = { + "point": reduced_pc, + "full_point": pc, + "feat": None, + "calib": calib, + "bounding_boxes": label if has_label else None, + "attr": attr, + } + return data, attr + + def load_files(self, scan_paths: Union[str, List[str]], dataset: Optional[BaseAutoDataset] = None): + if isinstance(scan_paths, str): + scan_paths = [scan_paths] + + def clean_fn(path: str) -> str: + return path.replace(self.scans_folder_name, self.calibrations_folder_name).replace(".bin", ".txt") + + dataset.path_list = scan_paths + + return [{"scan_path": scan_path, "calibration_path": clean_fn(scan_path)} for scan_path in scan_paths] + + def predict_load_data(self, data, dataset: Optional[BaseAutoDataset] = None): + if (isinstance(data, str) and isfile(data)) or (isinstance(data, list) and all(isfile(p) for p in data)): + return self.load_files(data, dataset) + elif isinstance(data, str) and isdir(data): + raise NotImplementedError + + def predict_load_sample(self, data, dataset: Optional[BaseAutoDataset] = None): + data, attr = self.load_sample(data, dataset, has_label=False) + # hack to prevent manipulation of labels + attr["split"] = "test" + return data, attr + + +class PointCloudObjectDetectorFoldersDataSource(DataSource): + def __init__( + self, + data_format: Optional[BaseDataFormat] = None, + image_size: tuple = (375, 1242), + **loader_kwargs, + ): + super().__init__() + + self.loaders = { + PointCloudObjectDetectionDataFormat.KITTI: KITTIPointCloudObjectDetectorLoader( + **loader_kwargs, image_size=image_size + ) + } + + self.data_format = data_format or PointCloudObjectDetectionDataFormat.KITTI + self.loader = self.loaders[self.data_format] + + def _validate_data(self, folder: str) -> None: + msg = f"The provided dataset for stage {self._running_stage} should be a folder. Found {folder}." + if not isinstance(folder, str): + raise MisconfigurationException(msg) + + if isinstance(folder, str) and not isdir(folder): + raise MisconfigurationException(msg) + + def load_data( + self, + data: Any, + dataset: Optional[BaseAutoDataset] = None, + ) -> Any: + + self._validate_data(data) + + return self.loader.load_data(data, dataset) + + def load_sample(self, metadata: Dict[str, str], dataset: Optional[BaseAutoDataset] = None) -> Any: + + data, metadata = self.loader.load_sample(metadata, dataset) + + preprocess_fn = getattr(dataset, "preprocess_fn", None) + if preprocess_fn: + data = preprocess_fn(data, metadata) + + transform_fn = getattr(dataset, "transform_fn", None) + if transform_fn: + data = transform_fn(data, metadata) + + return {"data": data, "attr": metadata} + + def _validate_predict_data(self, data: Union[str, List[str]]) -> None: + msg = f"The provided predict data should be a either a folder or a single/list of scan path(s). Found {data}." + if not isinstance(data, str) and not isinstance(data, list): + raise MisconfigurationException(msg) + + if isinstance(data, str) and (not isfile(data) or not isdir(data)): + raise MisconfigurationException(msg) + + if isinstance(data, list) and not all(isfile(p) for p in data): + raise MisconfigurationException(msg) + + def predict_load_data( + self, + data: Any, + dataset: Optional[BaseAutoDataset] = None, + ) -> Any: + + self._validate_predict_data(data) + + return self.loader.predict_load_data(data, dataset) + + def predict_load_sample( + self, + metadata: Any, + dataset: Optional[BaseAutoDataset] = None, + ) -> Any: + + data, metadata = self.loader.predict_load_sample(metadata, dataset) + + preprocess_fn = getattr(dataset, "preprocess_fn", None) + if preprocess_fn: + data = preprocess_fn(data, metadata) + + transform_fn = getattr(dataset, "transform_fn", None) + if transform_fn: + data = transform_fn(data, metadata) + + return {"data": data, "attr": metadata} diff --git a/flash/pointcloud/segmentation/__init__.py b/flash/pointcloud/segmentation/__init__.py new file mode 100644 index 0000000000..5d10606f79 --- /dev/null +++ b/flash/pointcloud/segmentation/__init__.py @@ -0,0 +1,3 @@ +from flash.pointcloud.segmentation.data import PointCloudSegmentationData # noqa: F401 +from flash.pointcloud.segmentation.model import PointCloudSegmentation # noqa: F401 +from flash.pointcloud.segmentation.open3d_ml.app import launch_app # noqa: F401 diff --git a/flash/pointcloud/segmentation/backbones.py b/flash/pointcloud/segmentation/backbones.py new file mode 100644 index 0000000000..023daa9ac0 --- /dev/null +++ b/flash/pointcloud/segmentation/backbones.py @@ -0,0 +1,19 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.core.registry import FlashRegistry +from flash.pointcloud.segmentation.open3d_ml.backbones import register_open_3d_ml + +POINTCLOUD_SEGMENTATION_BACKBONES = FlashRegistry("backbones") + +register_open_3d_ml(POINTCLOUD_SEGMENTATION_BACKBONES) diff --git a/flash/pointcloud/segmentation/cli.py b/flash/pointcloud/segmentation/cli.py new file mode 100644 index 0000000000..57d1125f9b --- /dev/null +++ b/flash/pointcloud/segmentation/cli.py @@ -0,0 +1,56 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from flash.core.data.utils import download_data +from flash.core.utilities.flash_cli import FlashCLI +from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData + +__all__ = ["pointcloud_segmentation"] + + +def from_kitti( + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> PointCloudSegmentationData: + """Downloads and loads the semantic KITTI data set.""" + download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/") + return PointCloudSegmentationData.from_folders( + train_folder="data/SemanticKittiTiny/train", + val_folder="data/SemanticKittiTiny/val", + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def pointcloud_segmentation(): + """Segment objects in point clouds.""" + cli = FlashCLI( + PointCloudSegmentation, + PointCloudSegmentationData, + default_datamodule_builder=from_kitti, + default_arguments={ + "trainer.max_epochs": 3, + "model.backbone": "randlanet_semantic_kitti", + }, + finetune=False, + ) + + cli.trainer.save_checkpoint("pointcloud_segmentation_model.pt") + + +if __name__ == "__main__": + pointcloud_segmentation() diff --git a/flash/pointcloud/segmentation/data.py b/flash/pointcloud/segmentation/data.py new file mode 100644 index 0000000000..92cd2cdbc2 --- /dev/null +++ b/flash/pointcloud/segmentation/data.py @@ -0,0 +1,93 @@ +from typing import Any, Callable, Dict, Optional, Tuple + +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.process import Deserializer, Preprocess +from flash.core.utilities.imports import requires_extras +from flash.pointcloud.segmentation.open3d_ml.sequences_dataset import SequencesDataset + + +class PointCloudSegmentationDatasetDataSource(DataSource): + def load_data( + self, + data: Any, + dataset: Optional[Any] = None, + ) -> Any: + if self.training: + dataset.num_classes = len(data.dataset.label_to_names) + + dataset.dataset = data + + return range(len(data)) + + def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any: + sample = dataset.dataset[index] + + return { + DefaultDataKeys.INPUT: sample["data"], + DefaultDataKeys.METADATA: sample["attr"], + } + + +class PointCloudSegmentationFoldersDataSource(DataSource): + @requires_extras("pointcloud") + def load_data( + self, + folder: Any, + dataset: Optional[Any] = None, + ) -> Any: + sequence_dataset = SequencesDataset(folder, use_cache=True, predicting=self.predicting) + dataset.dataset = sequence_dataset + if self.training: + dataset.num_classes = sequence_dataset.num_classes + + return range(len(sequence_dataset)) + + def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any: + sample = dataset.dataset[index] + + return { + DefaultDataKeys.INPUT: sample["data"], + DefaultDataKeys.METADATA: sample["attr"], + } + + +class PointCloudSegmentationPreprocess(Preprocess): + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + image_size: Tuple[int, int] = (196, 196), + deserializer: Optional[Deserializer] = None, + ): + self.image_size = image_size + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.DATASETS: PointCloudSegmentationDatasetDataSource(), + DefaultDataSources.FOLDERS: PointCloudSegmentationFoldersDataSource(), + }, + deserializer=deserializer, + default_data_source=DefaultDataSources.FOLDERS, + ) + + def get_state_dict(self): + return {} + + def state_dict(self): + return {} + + @classmethod + def load_state_dict(cls, state_dict, strict: bool = False): + pass + + +class PointCloudSegmentationData(DataModule): + + preprocess_cls = PointCloudSegmentationPreprocess diff --git a/flash/pointcloud/segmentation/datasets.py b/flash/pointcloud/segmentation/datasets.py new file mode 100644 index 0000000000..ff792282a4 --- /dev/null +++ b/flash/pointcloud/segmentation/datasets.py @@ -0,0 +1,62 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE + +if _POINTCLOUD_AVAILABLE: + from open3d.ml.datasets import Lyft, SemanticKITTI + +_SEGMENTATION_DATASET = FlashRegistry("dataset") + + +def executor(download_script, preprocess_script, dataset_path, name): + if not os.path.exists(os.path.join(dataset_path, name)): + os.system(f'bash -c "bash <(curl -s {download_script}) {dataset_path}"') + if preprocess_script: + os.system(f'bash -c "bash <(curl -s {preprocess_script}) {dataset_path}"') + + +@_SEGMENTATION_DATASET +def lyft(dataset_path): + name = "Lyft" + executor( + "https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/scripts/download_datasets/download_lyft.sh", + "https://github.com/intel-isl/Open3D-ML/blob/master/scripts/preprocess_lyft.py", + dataset_path, + name, + ) + return Lyft(os.path.join(dataset_path, name)) + + +def LyftDataset(dataset_path): + return _SEGMENTATION_DATASET.get("lyft")(dataset_path) + + +@_SEGMENTATION_DATASET +def semantickitti(dataset_path, download, **kwargs): + name = "SemanticKitti" + if download: + executor( + "https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/scripts/download_datasets/download_semantickitti.sh", # noqa E501 + None, + dataset_path, + name, + ) + return SemanticKITTI(os.path.join(dataset_path, name), **kwargs) + + +def SemanticKITTIDataset(dataset_path, download: bool = True, **kwargs): + return _SEGMENTATION_DATASET.get("semantickitti")(dataset_path, download, **kwargs) diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py new file mode 100644 index 0000000000..9342a61758 --- /dev/null +++ b/flash/pointcloud/segmentation/model.py @@ -0,0 +1,221 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union + +import torch +import torchmetrics +from pytorch_lightning import Callback, LightningModule +from torch import nn +from torch.nn import functional as F +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader, Sampler +from torchmetrics import IoU + +from flash.core.classification import ClassificationTask +from flash.core.data.auto_dataset import BaseAutoDataset +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.process import Serializer +from flash.core.data.states import CollateFn +from flash.core.finetuning import BaseFinetuning +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.pointcloud.segmentation.backbones import POINTCLOUD_SEGMENTATION_BACKBONES + +if _POINTCLOUD_AVAILABLE: + from open3d._ml3d.torch.modules.losses.semseg_loss import filter_valid_label + from open3d.ml.torch.dataloaders import TorchDataloader + + +class PointCloudSegmentationFinetuning(BaseFinetuning): + def __init__(self, num_layers: int = 5, train_bn: bool = True, unfreeze_epoch: int = 1): + super().__init__() + self.num_layers = num_layers + self.train_bn = train_bn + self.unfreeze_epoch = unfreeze_epoch + + def freeze_before_training(self, pl_module: LightningModule) -> None: + self.freeze(modules=list(pl_module.backbone.children())[: -self.num_layers], train_bn=self.train_bn) + + def finetune_function( + self, + pl_module: LightningModule, + epoch: int, + optimizer: Optimizer, + opt_idx: int, + ) -> None: + if epoch != self.unfreeze_epoch: + return + self.unfreeze_and_add_param_group( + modules=list(pl_module.backbone.children())[-self.num_layers :], + optimizer=optimizer, + train_bn=self.train_bn, + ) + + +class PointCloudSegmentationSerializer(Serializer): + pass + + +class PointCloudSegmentation(ClassificationTask): + """The ``PointCloudClassifier`` is a :class:`~flash.core.classification.ClassificationTask` that classifies + pointcloud data. + + Args: + num_features: The number of features (elements) in the input data. + num_classes: The number of classes (outputs) for this :class:`~flash.core.model.Task`. + backbone: The backbone name (or a tuple of ``nn.Module``, output size) to use. + backbone_kwargs: Any additional kwargs to pass to the backbone constructor. + loss_fn: The loss function to use. If ``None``, a default will be selected by the + :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. + optimizer: The optimizer or optimizer class to use. + optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance). + scheduler: The scheduler or scheduler class to use. + scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance). + metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected + by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. + learning_rate: The learning rate for the optimizer. + multi_label: If ``True``, this will be treated as a multi-label classification problem. + serializer: The :class:`~flash.core.data.process.Serializer` to use for prediction outputs. + """ + + backbones: FlashRegistry = POINTCLOUD_SEGMENTATION_BACKBONES + + required_extras: str = "pointcloud" + + def __init__( + self, + num_classes: int, + backbone: Union[str, Tuple[nn.Module, int]] = "RandLANet", + backbone_kwargs: Optional[Dict] = None, + head: Optional[nn.Module] = None, + loss_fn: Optional[Callable] = torch.nn.functional.cross_entropy, + optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, + scheduler_kwargs: Optional[Dict[str, Any]] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + learning_rate: float = 1e-2, + multi_label: bool = False, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudSegmentationSerializer(), + ): + import flash + + if metrics is None: + metrics = IoU(num_classes=num_classes) + + super().__init__( + model=None, + loss_fn=loss_fn, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + scheduler=scheduler, + scheduler_kwargs=scheduler_kwargs, + metrics=metrics, + learning_rate=learning_rate, + multi_label=multi_label, + serializer=serializer, + ) + + self.save_hyperparameters() + + if not backbone_kwargs: + backbone_kwargs = {"num_classes": num_classes} + + if isinstance(backbone, tuple): + self.backbone, out_features = backbone + else: + self.backbone, out_features, collate_fn = self.backbones.get(backbone)(**backbone_kwargs) + # replace latest layer + if not flash._IS_TESTING: + self.backbone.fc = nn.Identity() + self.set_state(CollateFn(collate_fn)) + + self.head = nn.Identity() if flash._IS_TESTING else (head or nn.Linear(out_features, num_classes)) + + def apply_filtering(self, labels, scores): + scores, labels = filter_valid_label(scores, labels, self.hparams.num_classes, [0], self.device) + return labels, scores + + def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor: + return F.softmax(self.to_loss_format(x), dim=-1) + + def to_loss_format(self, x: torch.Tensor) -> torch.Tensor: + return x.reshape(-1, x.shape[-1]) + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.INPUT]["labels"].view(-1)) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.INPUT]["labels"].view(-1)) + return super().validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.INPUT]["labels"].view(-1)) + return super().test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch[DefaultDataKeys.PREDS] = self(batch[DefaultDataKeys.INPUT]) + batch[DefaultDataKeys.TARGET] = batch[DefaultDataKeys.INPUT]["labels"] + # drop sub-sampled pointclouds + batch[DefaultDataKeys.INPUT] = batch[DefaultDataKeys.INPUT]["xyz"][0] + return batch + + def forward(self, x) -> torch.Tensor: + """First call the backbone, then the model head.""" + # hack to enable backbone to work properly. + self.backbone.device = self.device + x = self.backbone(x) + if self.head is not None: + x = self.head(x) + return x + + def _process_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + ) -> DataLoader: + + if not _POINTCLOUD_AVAILABLE: + raise ModuleNotFoundError("Please, run `pip install flash[pointcloud]`.") + + if not isinstance(dataset.dataset, TorchDataloader): + + dataset.dataset = TorchDataloader( + dataset.dataset, + preprocess=self.backbone.preprocess, + transform=self.backbone.transform, + use_cache=False, + ) + + return DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + ) + + def configure_finetune_callback(self) -> List[Callback]: + return [PointCloudSegmentationFinetuning()] diff --git a/flash/pointcloud/segmentation/open3d_ml/__init__.py b/flash/pointcloud/segmentation/open3d_ml/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flash/pointcloud/segmentation/open3d_ml/app.py b/flash/pointcloud/segmentation/open3d_ml/app.py new file mode 100644 index 0000000000..b1145c53b5 --- /dev/null +++ b/flash/pointcloud/segmentation/open3d_ml/app.py @@ -0,0 +1,107 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DefaultDataKeys +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE + +if _POINTCLOUD_AVAILABLE: + + from open3d._ml3d.torch.dataloaders import TorchDataloader + from open3d._ml3d.vis.visualizer import LabelLUT + from open3d._ml3d.vis.visualizer import Visualizer as Open3dVisualizer + +else: + + Open3dVisualizer = object + + +class Visualizer(Open3dVisualizer): + def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768): + """Visualize a dataset. + + Example: + Minimal example for visualizing a dataset:: + import open3d.ml.torch as ml3d # or open3d.ml.tf as ml3d + + dataset = ml3d.datasets.SemanticKITTI(dataset_path='/path/to/SemanticKITTI/') + vis = ml3d.vis.Visualizer() + vis.visualize_dataset(dataset, 'all', indices=range(100)) + + Args: + dataset: The dataset to use for visualization. + split: The dataset split to be used, such as 'training' + indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4]. + width: The width of the visualization window. + height: The height of the visualization window. + """ + # Setup the labels + lut = LabelLUT() + color_map = dataset.color_map + for id, val in dataset.label_to_names.items(): + lut.add_label(val, id, color=color_map[id]) + self.set_lut("labels", lut) + + self._consolidate_bounding_boxes = True + self._init_dataset(dataset, split, indices) + self._visualize("Open3D - " + dataset.name, width, height) + + +class App: + def __init__(self, datamodule: DataModule): + self.datamodule = datamodule + self._enabled = True # not flash._IS_TESTING + + def get_dataset(self, stage: str = "train"): + dataloader = getattr(self.datamodule, f"{stage}_dataloader")() + dataset = dataloader.dataset.dataset + if isinstance(dataset, TorchDataloader): + return dataset.dataset + return dataset + + def show_train_dataset(self, indices=None): + if self._enabled: + dataset = self.get_dataset("train") + viz = Visualizer() + viz.visualize_dataset(dataset, "all", indices=indices) + + def show_predictions(self, predictions): + if self._enabled: + dataset = self.get_dataset("train") + color_map = dataset.color_map + + predictions_visualizations = [] + for pred in predictions: + predictions_visualizations.append( + { + "points": torch.stack(pred[DefaultDataKeys.INPUT]), + "labels": torch.stack(pred[DefaultDataKeys.TARGET]), + "predictions": torch.argmax(torch.stack(pred[DefaultDataKeys.PREDS]), axis=-1) + 1, + "name": pred[DefaultDataKeys.METADATA]["name"], + } + ) + + viz = Visualizer() + lut = LabelLUT() + color_map = dataset.color_map + for id, val in dataset.label_to_names.items(): + lut.add_label(val, id, color=color_map[id]) + viz.set_lut("labels", lut) + viz.set_lut("predictions", lut) + viz.visualize(predictions_visualizations) + + +def launch_app(datamodule: DataModule) -> "App": + return App(datamodule) diff --git a/flash/pointcloud/segmentation/open3d_ml/backbones.py b/flash/pointcloud/segmentation/open3d_ml/backbones.py new file mode 100644 index 0000000000..a326cbcdc5 --- /dev/null +++ b/flash/pointcloud/segmentation/open3d_ml/backbones.py @@ -0,0 +1,82 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Callable + +import torch +from pytorch_lightning.utilities.cloud_io import load as pl_load + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.providers import _OPEN3D_ML + +ROOT_URL = "https://storage.googleapis.com/open3d-releases/model-zoo/" + + +def register_open_3d_ml(register: FlashRegistry): + if _POINTCLOUD_AVAILABLE: + import open3d + import open3d.ml as _ml3d + from open3d._ml3d.torch.dataloaders import ConcatBatcher, DefaultBatcher + from open3d._ml3d.torch.models import RandLANet + + CONFIG_PATH = os.path.join(os.path.dirname(open3d.__file__), "_ml3d/configs") + + def get_collate_fn(model) -> Callable: + batcher_name = model.cfg.batcher + if batcher_name == "DefaultBatcher": + batcher = DefaultBatcher() + elif batcher_name == "ConcatBatcher": + batcher = ConcatBatcher(torch, model.__class__.__name__) + else: + batcher = None + return batcher.collate_fn + + @register(providers=_OPEN3D_ML) + def randlanet_s3dis(*args, use_fold_5: bool = True, **kwargs) -> RandLANet: + cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_s3dis.yml")) + model = RandLANet(**cfg.model) + if use_fold_5: + weight_url = os.path.join(ROOT_URL, "randlanet_s3dis_area5_202010091333utc.pth") + else: + weight_url = os.path.join(ROOT_URL, "randlanet_s3dis_202010091238.pth") + model.load_state_dict(pl_load(weight_url, map_location="cpu")["model_state_dict"]) + return model, 32, get_collate_fn(model) + + @register(providers=_OPEN3D_ML) + def randlanet_toronto3d(*args, **kwargs) -> RandLANet: + cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_toronto3d.yml")) + model = RandLANet(**cfg.model) + model.load_state_dict( + pl_load(os.path.join(ROOT_URL, "randlanet_toronto3d_202010091306utc.pth"), map_location="cpu")[ + "model_state_dict" + ], + ) + return model, 32, get_collate_fn(model) + + @register(providers=_OPEN3D_ML) + def randlanet_semantic_kitti(*args, **kwargs) -> RandLANet: + cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_semantickitti.yml")) + model = RandLANet(**cfg.model) + model.load_state_dict( + pl_load(os.path.join(ROOT_URL, "randlanet_semantickitti_202009090354utc.pth"), map_location="cpu")[ + "model_state_dict" + ], + ) + return model, 32, get_collate_fn(model) + + @register(providers=_OPEN3D_ML) + def randlanet(*args, **kwargs) -> RandLANet: + model = RandLANet(*args, **kwargs) + return model, 32, get_collate_fn(model) diff --git a/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py b/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py new file mode 100644 index 0000000000..966b224c78 --- /dev/null +++ b/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py @@ -0,0 +1,182 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from os.path import basename, dirname, exists, isdir, isfile, join, split + +import numpy as np +import yaml +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch.utils.data import Dataset + +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE + +if _POINTCLOUD_AVAILABLE: + + from open3d._ml3d.datasets.utils import DataProcessing + from open3d._ml3d.utils.config import Config + + +class SequencesDataset(Dataset): + def __init__( + self, + data, + cache_dir="./logs/cache", + use_cache=False, + num_points=65536, + ignored_label_inds=[0], + predicting=False, + **kwargs, + ): + + super().__init__() + + self.name = "Dataset" + self.ignored_label_inds = ignored_label_inds + + kwargs["cache_dir"] = cache_dir + kwargs["use_cache"] = use_cache + kwargs["num_points"] = num_points + kwargs["ignored_label_inds"] = ignored_label_inds + + self.cfg = Config(kwargs) + self.predicting = predicting + + if not predicting: + self.on_fit(data) + else: + self.on_predict(data) + + @property + def color_map(self): + return self.meta["color_map"] + + def on_fit(self, dataset_path): + self.split = basename(dataset_path) + + self.load_meta(dirname(dataset_path)) + self.dataset_path = dataset_path + self.label_to_names = self.get_label_to_names() + self.num_classes = len(self.label_to_names) - len(self.ignored_label_inds) + self.make_datasets() + + def load_meta(self, root_dir): + meta_file = join(root_dir, "meta.yaml") + if not exists(meta_file): + raise MisconfigurationException( + f"The {root_dir} should contain a `meta.yaml` file about the pointcloud sequences." + ) + + with open(meta_file) as f: + self.meta = yaml.safe_load(f) + + self.label_to_names = self.get_label_to_names() + self.num_classes = len(self.label_to_names) + + with open(meta_file) as f: + self.meta = yaml.safe_load(f) + + remap_dict_val = self.meta["learning_map"] + max_key = max(remap_dict_val.keys()) + remap_lut_val = np.zeros((max_key + 100), dtype=np.int32) + remap_lut_val[list(remap_dict_val.keys())] = list(remap_dict_val.values()) + + self.remap_lut_val = remap_lut_val + + def make_datasets(self): + self.path_list = [] + for seq in os.listdir(self.dataset_path): + sequence_path = join(self.dataset_path, seq) + directories = [f for f in os.listdir(sequence_path) if isdir(join(sequence_path, f)) and f != "labels"] + assert len(directories) == 1 + scan_dir = join(sequence_path, directories[0]) + for scan_name in os.listdir(scan_dir): + self.path_list.append(join(scan_dir, scan_name)) + + def on_predict(self, data): + if isinstance(data, list): + if not all(isfile(p) for p in data): + raise MisconfigurationException("The predict input data takes only a list of paths or a directory.") + root_dir = split(data[0])[0] + elif isinstance(data, str): + if not isdir(data) and not isfile(data): + raise MisconfigurationException("The predict input data takes only a list of paths or a directory.") + if isdir(data): + root_dir = data + data = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if ".bin" in f] + elif isfile(data): + root_dir = dirname(data) + data = [data] + else: + raise MisconfigurationException("The predict input data takes only a list of paths or a directory.") + else: + raise MisconfigurationException("The predict input data takes only a list of paths or a directory.") + + self.path_list = data + self.split = "predict" + self.load_meta(root_dir) + + def get_label_to_names(self): + """Returns a label to names dictonary object. + + Returns: + A dict where keys are label numbers and + values are the corresponding names. + """ + return self.meta["label_to_names"] + + def __getitem__(self, index): + data = self.get_data(index) + data["attr"] = self.get_attr(index) + return data + + def get_data(self, idx): + pc_path = self.path_list[idx] + points = DataProcessing.load_pc_kitti(pc_path) + + dir, file = split(pc_path) + if self.predicting: + label_path = join(dir, file[:-4] + ".label") + else: + label_path = join(dir, "../labels", file[:-4] + ".label") + if not exists(label_path): + labels = np.zeros(np.shape(points)[0], dtype=np.int32) + if self.split not in ["test", "all"]: + raise FileNotFoundError(f" Label file {label_path} not found") + + else: + labels = DataProcessing.load_label_kitti(label_path, self.remap_lut_val).astype(np.int32) + + data = { + "point": points[:, 0:3], + "feat": None, + "label": labels, + } + + return data + + def get_attr(self, idx): + pc_path = self.path_list[idx] + dir, file = split(pc_path) + _, seq = split(split(dir)[0]) + name = f"{seq}_{file[:-4]}" + + pc_path = str(pc_path) + attr = {"idx": idx, "name": name, "path": pc_path, "split": self.split} + return attr + + def __len__(self): + return len(self.path_list) + + def get_split(self, *_): + return self diff --git a/flash/setup_tools.py b/flash/setup_tools.py index 8e27bf2c1c..a7376eb940 100644 --- a/flash/setup_tools.py +++ b/flash/setup_tools.py @@ -19,17 +19,17 @@ _PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) -def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_chars: str = '#@') -> List[str]: - with open(os.path.join(path_dir, file_name), 'r') as file: +def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_chars: str = "#@") -> List[str]: + with open(os.path.join(path_dir, file_name)) as file: lines = [ln.strip() for ln in file.readlines()] reqs = [] for ln in lines: # filer all comments found = [ln.index(ch) for ch in comment_chars if ch in ln] if found: - ln = ln[:min(found)].strip() + ln = ln[: min(found)].strip() # skip directly installed dependencies - if ln.startswith('http') or ln.startswith('git'): + if ln.startswith("http") or ln.startswith("git"): continue if ln: # if requirement is not empty reqs.append(ln) @@ -37,7 +37,7 @@ def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comme def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str: - """Load readme as decribtion + """Load readme as decribtion. >>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE '
...' @@ -46,7 +46,7 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str: text = open(path_readme, encoding="utf-8").read() # drop images from readme - text = text.replace('![PT to PL](docs/source/_images/general/pl_quick_start_full_compressed.gif)', '') + text = text.replace("![PT to PL](docs/source/_images/general/pl_quick_start_full_compressed.gif)", "") # https://github.com/PyTorchLightning/pytorch-lightning/raw/master/docs/source/_images/lightning_module/pt_to_pl.png github_source_url = os.path.join(homepage, "raw", ver) @@ -55,17 +55,17 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str: text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}") # readthedocs badge - text = text.replace('badge/?version=stable', f'badge/?version={ver}') - text = text.replace('pytorch-lightning.readthedocs.io/en/stable/', f'pytorch-lightning.readthedocs.io/en/{ver}') + text = text.replace("badge/?version=stable", f"badge/?version={ver}") + text = text.replace("pytorch-lightning.readthedocs.io/en/stable/", f"pytorch-lightning.readthedocs.io/en/{ver}") # codecov badge - text = text.replace('/branch/master/graph/badge.svg', f'/release/{ver}/graph/badge.svg') + text = text.replace("/branch/master/graph/badge.svg", f"/release/{ver}/graph/badge.svg") # replace github badges for release ones - text = text.replace('badge.svg?branch=master&event=push', f'badge.svg?tag={ver}') + text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={ver}") - skip_begin = r'' - skip_end = r'' + skip_begin = r"" + skip_end = r"" # todo: wrap content as commented description - text = re.sub(rf"{skip_begin}.+?{skip_end}", '', text, flags=re.IGNORECASE + re.DOTALL) + text = re.sub(rf"{skip_begin}.+?{skip_end}", "", text, flags=re.IGNORECASE + re.DOTALL) # # https://github.com/Borda/pytorch-lightning/releases/download/1.1.0a6/codecov_badge.png # github_release_url = os.path.join(homepage, "releases", "download", ver) diff --git a/flash/tabular/__init__.py b/flash/tabular/__init__.py index a3b8e2ca2d..22698efc99 100644 --- a/flash/tabular/__init__.py +++ b/flash/tabular/__init__.py @@ -1 +1,3 @@ -from flash.tabular.classification import TabularClassifier, TabularData # noqa: F401 +from flash.tabular.classification import TabularClassificationData, TabularClassifier # noqa: F401 +from flash.tabular.data import TabularData # noqa: F401 +from flash.tabular.regression import TabularRegressionData # noqa: F401 diff --git a/flash/tabular/classification/__init__.py b/flash/tabular/classification/__init__.py index 45724db27b..6134277abf 100644 --- a/flash/tabular/classification/__init__.py +++ b/flash/tabular/classification/__init__.py @@ -1,2 +1,2 @@ -from flash.tabular.classification.data import TabularData # noqa: F401 +from flash.tabular.classification.data import TabularClassificationData # noqa: F401 from flash.tabular.classification.model import TabularClassifier # noqa: F401 diff --git a/flash/tabular/classification/cli.py b/flash/tabular/classification/cli.py new file mode 100644 index 0000000000..63eff2458f --- /dev/null +++ b/flash/tabular/classification/cli.py @@ -0,0 +1,59 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from flash.core.data.utils import download_data +from flash.core.utilities.flash_cli import FlashCLI +from flash.tabular import TabularClassificationData, TabularClassifier + +__all__ = ["tabular_classification"] + + +def from_titanic( + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> TabularClassificationData: + """Downloads and loads the Titanic data set.""" + download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data") + return TabularClassificationData.from_csv( + ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], + "Fare", + target_fields="Survived", + train_file="data/titanic/titanic.csv", + val_split=0.1, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def tabular_classification(): + """Classify tabular data.""" + cli = FlashCLI( + TabularClassifier, + TabularClassificationData, + default_datamodule_builder=from_titanic, + default_arguments={ + "trainer.max_epochs": 3, + }, + finetune=False, + datamodule_attributes={"num_features", "num_classes", "embedding_sizes"}, + ) + + cli.trainer.save_checkpoint("tabular_classification_model.pt") + + +if __name__ == "__main__": + tabular_classification() diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py index c2a60e24da..63cdda9ea2 100644 --- a/flash/tabular/classification/data.py +++ b/flash/tabular/classification/data.py @@ -11,505 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from io import StringIO -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from flash.tabular.data import TabularData -import numpy as np -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from flash.core.classification import LabelsState -from flash.core.data.callback import BaseDataFetcher -from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Deserializer, Postprocess, Preprocess -from flash.core.utilities.imports import _PANDAS_AVAILABLE -from flash.tabular.classification.utils import ( - _compute_normalization, - _generate_codes, - _pre_transform, - _to_cat_vars_numpy, - _to_num_vars_numpy, -) - -if _PANDAS_AVAILABLE: - import pandas as pd - from pandas.core.frame import DataFrame -else: - DataFrame = object - - -class TabularDataFrameDataSource(DataSource[DataFrame]): - - def __init__( - self, - cat_cols: Optional[List[str]] = None, - num_cols: Optional[List[str]] = None, - target_col: Optional[str] = None, - mean: Optional[DataFrame] = None, - std: Optional[DataFrame] = None, - codes: Optional[Dict[str, Any]] = None, - target_codes: Optional[Dict[str, Any]] = None, - classes: Optional[List[str]] = None, - is_regression: bool = True, - ): - super().__init__() - - self.cat_cols = cat_cols - self.num_cols = num_cols - self.target_col = target_col - self.mean = mean - self.std = std - self.codes = codes - self.target_codes = target_codes - self.is_regression = is_regression - - self.set_state(LabelsState(classes)) - self.num_classes = len(classes) - - def common_load_data( - self, - df: DataFrame, - dataset: Optional[Any] = None, - ): - # impute_data - # compute train dataset stats - dfs = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, - self.target_codes) - - df = dfs[0] - - if dataset is not None: - dataset.num_samples = len(df) - - cat_vars = _to_cat_vars_numpy(df, self.cat_cols) - num_vars = _to_num_vars_numpy(df, self.num_cols) - - cat_vars = np.stack(cat_vars, 1) # if len(cat_vars) else np.zeros((len(self), 0)) - num_vars = np.stack(num_vars, 1) # if len(num_vars) else np.zeros((len(self), 0)) - return df, cat_vars, num_vars - - def load_data(self, data: DataFrame, dataset: Optional[Any] = None): - df, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) - target = df[self.target_col].to_numpy().astype(np.float32 if self.is_regression else np.int64) - return [{ - DefaultDataKeys.INPUT: (c, n), - DefaultDataKeys.TARGET: t - } for c, n, t in zip(cat_vars, num_vars, target)] - - def predict_load_data(self, data: DataFrame, dataset: Optional[Any] = None): - _, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) - return [{DefaultDataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)] - - -class TabularCSVDataSource(TabularDataFrameDataSource): - - def load_data(self, data: str, dataset: Optional[Any] = None): - return super().load_data(pd.read_csv(data), dataset=dataset) - - def predict_load_data(self, data: str, dataset: Optional[Any] = None): - return super().predict_load_data(pd.read_csv(data), dataset=dataset) - - -class TabularDeserializer(Deserializer): - - def __init__( - self, - cat_cols: Optional[List[str]] = None, - num_cols: Optional[List[str]] = None, - target_col: Optional[str] = None, - mean: Optional[DataFrame] = None, - std: Optional[DataFrame] = None, - codes: Optional[Dict[str, Any]] = None, - target_codes: Optional[Dict[str, Any]] = None, - classes: Optional[List[str]] = None, - is_regression: bool = True - ): - super().__init__() - self.cat_cols = cat_cols - self.num_cols = num_cols - self.target_col = target_col - self.mean = mean - self.std = std - self.codes = codes - self.target_codes = target_codes - self.classes = classes - self.is_regression = is_regression - - def deserialize(self, data: str) -> Any: - df = pd.read_csv(StringIO(data)) - df = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, - self.target_codes)[0] - - cat_vars = _to_cat_vars_numpy(df, self.cat_cols) - num_vars = _to_num_vars_numpy(df, self.num_cols) - - cat_vars = np.stack(cat_vars, 1) - num_vars = np.stack(num_vars, 1) - - return [{DefaultDataKeys.INPUT: [c, n]} for c, n in zip(cat_vars, num_vars)] - - @property - def example_input(self) -> str: - row = {} - for cat_col in self.cat_cols: - row[cat_col] = ["test"] - for num_col in self.num_cols: - row[num_col] = [0] - return str(DataFrame.from_dict(row).to_csv()) - - -class TabularPreprocess(Preprocess): - - def __init__( - self, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, - cat_cols: Optional[List[str]] = None, - num_cols: Optional[List[str]] = None, - target_col: Optional[str] = None, - mean: Optional[DataFrame] = None, - std: Optional[DataFrame] = None, - codes: Optional[Dict[str, Any]] = None, - target_codes: Optional[Dict[str, Any]] = None, - classes: Optional[List[str]] = None, - is_regression: bool = True, - deserializer: Optional[Deserializer] = None - ): - self.cat_cols = cat_cols - self.num_cols = num_cols - self.target_col = target_col - self.mean = mean - self.std = std - self.codes = codes - self.target_codes = target_codes - self.classes = classes - self.is_regression = is_regression - - super().__init__( - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, - data_sources={ - DefaultDataSources.CSV: TabularCSVDataSource( - cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression - ), - "data_frame": TabularDataFrameDataSource( - cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression - ), - }, - default_data_source=DefaultDataSources.CSV, - deserializer=deserializer or TabularDeserializer( - cat_cols=cat_cols, - num_cols=num_cols, - target_col=target_col, - mean=mean, - std=std, - codes=codes, - target_codes=target_codes, - classes=classes, - is_regression=is_regression - ) - ) - - def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: - return { - **self.transforms, - "cat_cols": self.cat_cols, - "num_cols": self.num_cols, - "target_col": self.target_col, - "mean": self.mean, - "std": self.std, - "codes": self.codes, - "target_codes": self.target_codes, - "classes": self.classes, - "is_regression": self.is_regression, - } - - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> 'Preprocess': - return cls(**state_dict) - - -class TabularPostprocess(Postprocess): - - def uncollate(self, batch: Any) -> Any: - return batch - - -class TabularData(DataModule): - """Data module for tabular tasks""" - - preprocess_cls = TabularPreprocess - postprocess_cls = TabularPostprocess - - @property - def codes(self) -> Dict[str, str]: - return self._data_source.codes - - @property - def num_classes(self) -> int: - return self._data_source.num_classes - - @property - def cat_cols(self) -> Optional[List[str]]: - return self._data_source.cat_cols - - @property - def num_cols(self) -> Optional[List[str]]: - return self._data_source.num_cols - - @property - def num_features(self) -> int: - return len(self.cat_cols) + len(self.num_cols) - - @property - def emb_sizes(self) -> list: - """Recommended embedding sizes.""" - - # https://developers.googleblog.com/2017/11/introducing-tensorflow-feature-columns.html - # The following "formula" provides a general rule of thumb about the number of embedding dimensions: - # embedding_dimensions = number_of_categories**0.25 - num_classes = [len(self.codes[cat]) for cat in self.cat_cols] - emb_dims = [max(int(n**0.25), 16) for n in num_classes] - return list(zip(num_classes, emb_dims)) - - @staticmethod - def _sanetize_cols(cat_cols: Optional[Union[str, List[str]]], num_cols: Optional[Union[str, List[str]]]): - if cat_cols is None and num_cols is None: - raise RuntimeError('Both `cat_cols` and `num_cols` are None!') - - return cat_cols or [], num_cols or [] - - @classmethod - def compute_state( - cls, - train_data_frame: DataFrame, - val_data_frame: Optional[DataFrame], - test_data_frame: Optional[DataFrame], - predict_data_frame: Optional[DataFrame], - target_fields: str, - numerical_fields: List[str], - categorical_fields: List[str], - ) -> Tuple[float, float, List[str], Dict[str, Any], Dict[str, Any]]: - - if train_data_frame is None: - raise MisconfigurationException( - "train_data_frame is required to instantiate the TabularDataFrameDataSource" - ) - - data_frames = [train_data_frame] - - if val_data_frame is not None: - data_frames += [val_data_frame] - - if test_data_frame is not None: - data_frames += [test_data_frame] - - if predict_data_frame is not None: - data_frames += [predict_data_frame] - - mean, std = _compute_normalization(data_frames[0], numerical_fields) - - classes = list(data_frames[0][target_fields].unique()) - - if data_frames[0][target_fields].dtype == object: - # if the target_fields is a category, not an int - target_codes = _generate_codes(data_frames, [target_fields]) - else: - target_codes = None - codes = _generate_codes(data_frames, categorical_fields) - - return mean, std, classes, codes, target_codes - - @classmethod - def from_data_frame( - cls, - categorical_fields: Optional[Union[str, List[str]]], - numerical_fields: Optional[Union[str, List[str]]], - target_fields: Optional[str] = None, - train_data_frame: Optional[DataFrame] = None, - val_data_frame: Optional[DataFrame] = None, - test_data_frame: Optional[DataFrame] = None, - predict_data_frame: Optional[DataFrame] = None, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, - val_split: Optional[float] = None, - batch_size: int = 4, - num_workers: Optional[int] = None, - is_regression: bool = False, - **preprocess_kwargs: Any, - ): - """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames. - - Args: - categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs. - numerical_fields: The field or fields (columns) in the CSV file containing numerical inputs. - target_fields: The field or fields (columns) in the CSV file to use for the target. - train_data_frame: The pandas ``DataFrame`` containing the training data. - val_data_frame: The pandas ``DataFrame`` containing the validation data. - test_data_frame: The pandas ``DataFrame`` containing the testing data. - predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting. - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the - :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` - will be constructed and used. - val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - is_regression: If ``True``, targets will be formatted as floating point. If ``False``, targets will be - formatted as integers. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. - - Returns: - The constructed data module. - - Examples:: - - data_module = TabularData.from_data_frame( - "categorical_input", - "numerical_input", - "target", - train_data_frame=train_data, - ) - """ - categorical_fields, numerical_fields = cls._sanetize_cols(categorical_fields, numerical_fields) - - if not isinstance(categorical_fields, list): - categorical_fields = [categorical_fields] - - if not isinstance(numerical_fields, list): - numerical_fields = [numerical_fields] - - mean, std, classes, codes, target_codes = cls.compute_state( - train_data_frame=train_data_frame, - val_data_frame=val_data_frame, - test_data_frame=test_data_frame, - predict_data_frame=predict_data_frame, - target_fields=target_fields, - numerical_fields=numerical_fields, - categorical_fields=categorical_fields, - ) - - return cls.from_data_source( - "data_frame", - train_data_frame, - val_data_frame, - test_data_frame, - predict_data_frame, - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, - data_fetcher=data_fetcher, - preprocess=preprocess, - val_split=val_split, - batch_size=batch_size, - num_workers=num_workers, - cat_cols=categorical_fields, - num_cols=numerical_fields, - target_col=target_fields, - mean=mean, - std=std, - codes=codes, - target_codes=target_codes, - classes=classes, - is_regression=is_regression, - **preprocess_kwargs, - ) - - @classmethod - def from_csv( - cls, - categorical_fields: Optional[Union[str, List[str]]], - numerical_fields: Optional[Union[str, List[str]]], - target_fields: Optional[str] = None, - train_file: Optional[str] = None, - val_file: Optional[str] = None, - test_file: Optional[str] = None, - predict_file: Optional[str] = None, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, - val_split: Optional[float] = None, - batch_size: int = 4, - num_workers: Optional[int] = None, - is_regression: bool = False, - **preprocess_kwargs: Any, - ) -> 'DataModule': - """Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files. - - Args: - categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs. - numerical_fields: The field or fields (columns) in the CSV file containing numerical inputs. - target_fields: The field or fields (columns) in the CSV file to use for the target. - train_file: The CSV file containing the training data. - val_file: The CSV file containing the validation data. - test_file: The CSV file containing the testing data. - predict_file: The CSV file containing the data to use when predicting. - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the - :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` - will be constructed and used. - val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - is_regression: If ``True``, targets will be formatted as floating point. If ``False``, targets will be - formatted as integers. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. - - Returns: - The constructed data module. - - Examples:: - - data_module = TabularData.from_csv( - "categorical_input", - "numerical_input", - "target", - train_file="train_data.csv", - ) - """ - return cls.from_data_frame( - categorical_fields=categorical_fields, - numerical_fields=numerical_fields, - target_fields=target_fields, - train_data_frame=pd.read_csv(train_file) if train_file is not None else None, - val_data_frame=pd.read_csv(val_file) if val_file is not None else None, - test_data_frame=pd.read_csv(test_file) if test_file is not None else None, - predict_data_frame=pd.read_csv(predict_file) if predict_file is not None else None, - is_regression=is_regression, - preprocess=preprocess, - val_split=val_split, - batch_size=batch_size, - num_workers=num_workers, - ) +class TabularClassificationData(TabularData): + is_regression = False diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 3106bd57c9..b01e99e4f6 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -53,7 +53,7 @@ def __init__( self, num_features: int, num_classes: int, - embedding_sizes: List[Tuple] = None, + embedding_sizes: List[Tuple[int, int]] = None, loss_fn: Callable = F.cross_entropy, optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, @@ -71,7 +71,7 @@ def __init__( cat_idxs=list(range(len(embedding_sizes))), cat_dims=list(cat_dims), cat_emb_dim=list(cat_emb_dim), - **tabnet_kwargs + **tabnet_kwargs, ) super().__init__( @@ -108,17 +108,15 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = (batch[DefaultDataKeys.INPUT]) + batch = batch[DefaultDataKeys.INPUT] return self(batch) @classmethod - def from_data(cls, datamodule, **kwargs) -> 'TabularClassifier': - model = cls(datamodule.num_features, datamodule.num_classes, datamodule.emb_sizes, **kwargs) + def from_data(cls, datamodule, **kwargs) -> "TabularClassifier": + model = cls(datamodule.num_features, datamodule.num_classes, datamodule.embedding_sizes, **kwargs) return model @staticmethod def _ci_benchmark_fn(history: List[Dict[str, Any]]): - """ - This function is used only for debugging usage with CI - """ - assert history[-1]["val_accuracy"] > 0.65 + """This function is used only for debugging usage with CI.""" + assert history[-1]["val_accuracy"] > 0.6, history[-1]["val_accuracy"] diff --git a/flash/tabular/data.py b/flash/tabular/data.py new file mode 100644 index 0000000000..da36d726ce --- /dev/null +++ b/flash/tabular/data.py @@ -0,0 +1,509 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from io import StringIO +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from flash.core.classification import LabelsState +from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.process import Deserializer, Postprocess, Preprocess +from flash.core.utilities.imports import _PANDAS_AVAILABLE +from flash.tabular.classification.utils import ( + _compute_normalization, + _generate_codes, + _pre_transform, + _to_cat_vars_numpy, + _to_num_vars_numpy, +) + +if _PANDAS_AVAILABLE: + import pandas as pd + from pandas.core.frame import DataFrame +else: + DataFrame = object + + +class TabularDataFrameDataSource(DataSource[DataFrame]): + def __init__( + self, + cat_cols: Optional[List[str]] = None, + num_cols: Optional[List[str]] = None, + target_col: Optional[str] = None, + mean: Optional[DataFrame] = None, + std: Optional[DataFrame] = None, + codes: Optional[Dict[str, Any]] = None, + target_codes: Optional[Dict[str, Any]] = None, + classes: Optional[List[str]] = None, + is_regression: bool = True, + ): + super().__init__() + + self.cat_cols = cat_cols + self.num_cols = num_cols + self.target_col = target_col + self.mean = mean + self.std = std + self.codes = codes + self.target_codes = target_codes + self.is_regression = is_regression + + self.set_state(LabelsState(classes)) + self.num_classes = len(classes) + + def common_load_data( + self, + df: DataFrame, + dataset: Optional[Any] = None, + ): + # impute_data + # compute train dataset stats + dfs = _pre_transform( + [df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, self.target_codes + ) + + df = dfs[0] + + if dataset is not None: + dataset.num_samples = len(df) + + cat_vars = _to_cat_vars_numpy(df, self.cat_cols) + num_vars = _to_num_vars_numpy(df, self.num_cols) + + cat_vars = np.stack(cat_vars, 1) # if len(cat_vars) else np.zeros((len(self), 0)) + num_vars = np.stack(num_vars, 1) # if len(num_vars) else np.zeros((len(self), 0)) + return df, cat_vars, num_vars + + def load_data(self, data: DataFrame, dataset: Optional[Any] = None): + df, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) + target = df[self.target_col].to_numpy().astype(np.float32 if self.is_regression else np.int64) + return [ + {DefaultDataKeys.INPUT: (c, n), DefaultDataKeys.TARGET: t} for c, n, t in zip(cat_vars, num_vars, target) + ] + + def predict_load_data(self, data: DataFrame, dataset: Optional[Any] = None): + _, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) + return [{DefaultDataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)] + + +class TabularCSVDataSource(TabularDataFrameDataSource): + def load_data(self, data: str, dataset: Optional[Any] = None): + return super().load_data(pd.read_csv(data), dataset=dataset) + + def predict_load_data(self, data: str, dataset: Optional[Any] = None): + return super().predict_load_data(pd.read_csv(data), dataset=dataset) + + +class TabularDeserializer(Deserializer): + def __init__( + self, + cat_cols: Optional[List[str]] = None, + num_cols: Optional[List[str]] = None, + target_col: Optional[str] = None, + mean: Optional[DataFrame] = None, + std: Optional[DataFrame] = None, + codes: Optional[Dict[str, Any]] = None, + target_codes: Optional[Dict[str, Any]] = None, + classes: Optional[List[str]] = None, + is_regression: bool = True, + ): + super().__init__() + self.cat_cols = cat_cols + self.num_cols = num_cols + self.target_col = target_col + self.mean = mean + self.std = std + self.codes = codes + self.target_codes = target_codes + self.classes = classes + self.is_regression = is_regression + + def deserialize(self, data: str) -> Any: + df = pd.read_csv(StringIO(data)) + df = _pre_transform( + [df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, self.target_codes + )[0] + + cat_vars = _to_cat_vars_numpy(df, self.cat_cols) + num_vars = _to_num_vars_numpy(df, self.num_cols) + + cat_vars = np.stack(cat_vars, 1) + num_vars = np.stack(num_vars, 1) + + return [{DefaultDataKeys.INPUT: [c, n]} for c, n in zip(cat_vars, num_vars)] + + @property + def example_input(self) -> str: + row = {} + for cat_col in self.cat_cols: + row[cat_col] = ["test"] + for num_col in self.num_cols: + row[num_col] = [0] + return str(DataFrame.from_dict(row).to_csv()) + + +class TabularPreprocess(Preprocess): + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + cat_cols: Optional[List[str]] = None, + num_cols: Optional[List[str]] = None, + target_col: Optional[str] = None, + mean: Optional[DataFrame] = None, + std: Optional[DataFrame] = None, + codes: Optional[Dict[str, Any]] = None, + target_codes: Optional[Dict[str, Any]] = None, + classes: Optional[List[str]] = None, + is_regression: bool = True, + deserializer: Optional[Deserializer] = None, + ): + classes = classes or [] + + self.cat_cols = cat_cols + self.num_cols = num_cols + self.target_col = target_col + self.mean = mean + self.std = std + self.codes = codes + self.target_codes = target_codes + self.classes = classes + self.is_regression = is_regression + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.CSV: TabularCSVDataSource( + cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression + ), + "data_frame": TabularDataFrameDataSource( + cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression + ), + }, + default_data_source=DefaultDataSources.CSV, + deserializer=deserializer + or TabularDeserializer( + cat_cols=cat_cols, + num_cols=num_cols, + target_col=target_col, + mean=mean, + std=std, + codes=codes, + target_codes=target_codes, + classes=classes, + is_regression=is_regression, + ), + ) + + def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: + return { + **self.transforms, + "cat_cols": self.cat_cols, + "num_cols": self.num_cols, + "target_col": self.target_col, + "mean": self.mean, + "std": self.std, + "codes": self.codes, + "target_codes": self.target_codes, + "classes": self.classes, + "is_regression": self.is_regression, + } + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> "Preprocess": + return cls(**state_dict) + + +class TabularPostprocess(Postprocess): + def uncollate(self, batch: Any) -> Any: + return batch + + +class TabularData(DataModule): + """Data module for tabular tasks.""" + + preprocess_cls = TabularPreprocess + postprocess_cls = TabularPostprocess + + is_regression: bool = False + + @property + def codes(self) -> Dict[str, str]: + return self._data_source.codes + + @property + def num_classes(self) -> int: + return self._data_source.num_classes + + @property + def cat_cols(self) -> Optional[List[str]]: + return self._data_source.cat_cols + + @property + def num_cols(self) -> Optional[List[str]]: + return self._data_source.num_cols + + @property + def num_features(self) -> int: + return len(self.cat_cols) + len(self.num_cols) + + @property + def embedding_sizes(self) -> list: + """Recommended embedding sizes.""" + + # https://developers.googleblog.com/2017/11/introducing-tensorflow-feature-columns.html + # The following "formula" provides a general rule of thumb about the number of embedding dimensions: + # embedding_dimensions = number_of_categories**0.25 + num_classes = [len(self.codes[cat]) for cat in self.cat_cols] + emb_dims = [max(int(n ** 0.25), 16) for n in num_classes] + return list(zip(num_classes, emb_dims)) + + @staticmethod + def _sanetize_cols(cat_cols: Optional[Union[str, List[str]]], num_cols: Optional[Union[str, List[str]]]): + if cat_cols is None and num_cols is None: + raise RuntimeError("Both `cat_cols` and `num_cols` are None!") + + return cat_cols or [], num_cols or [] + + @classmethod + def compute_state( + cls, + train_data_frame: DataFrame, + val_data_frame: Optional[DataFrame], + test_data_frame: Optional[DataFrame], + predict_data_frame: Optional[DataFrame], + target_fields: str, + numerical_fields: List[str], + categorical_fields: List[str], + ) -> Tuple[float, float, List[str], Dict[str, Any], Dict[str, Any]]: + + if train_data_frame is None: + raise MisconfigurationException( + "train_data_frame is required to instantiate the TabularDataFrameDataSource" + ) + + data_frames = [train_data_frame] + + if val_data_frame is not None: + data_frames += [val_data_frame] + + if test_data_frame is not None: + data_frames += [test_data_frame] + + if predict_data_frame is not None: + data_frames += [predict_data_frame] + + mean, std = _compute_normalization(data_frames[0], numerical_fields) + + classes = list(data_frames[0][target_fields].unique()) + + if data_frames[0][target_fields].dtype == object: + # if the target_fields is a category, not an int + target_codes = _generate_codes(data_frames, [target_fields]) + else: + target_codes = None + codes = _generate_codes(data_frames, categorical_fields) + + return mean, std, classes, codes, target_codes + + @classmethod + def from_data_frame( + cls, + categorical_fields: Optional[Union[str, List[str]]], + numerical_fields: Optional[Union[str, List[str]]], + target_fields: Optional[str] = None, + train_data_frame: Optional[DataFrame] = None, + val_data_frame: Optional[DataFrame] = None, + test_data_frame: Optional[DataFrame] = None, + predict_data_frame: Optional[DataFrame] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames. + + Args: + categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs. + numerical_fields: The field or fields (columns) in the CSV file containing numerical inputs. + target_fields: The field or fields (columns) in the CSV file to use for the target. + train_data_frame: The pandas ``DataFrame`` containing the training data. + val_data_frame: The pandas ``DataFrame`` containing the validation data. + test_data_frame: The pandas ``DataFrame`` containing the testing data. + predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = TabularData.from_data_frame( + "categorical_input", + "numerical_input", + "target", + train_data_frame=train_data, + ) + """ + categorical_fields, numerical_fields = cls._sanetize_cols(categorical_fields, numerical_fields) + + if not isinstance(categorical_fields, list): + categorical_fields = [categorical_fields] + + if not isinstance(numerical_fields, list): + numerical_fields = [numerical_fields] + + mean, std, classes, codes, target_codes = cls.compute_state( + train_data_frame=train_data_frame, + val_data_frame=val_data_frame, + test_data_frame=test_data_frame, + predict_data_frame=predict_data_frame, + target_fields=target_fields, + numerical_fields=numerical_fields, + categorical_fields=categorical_fields, + ) + + return cls.from_data_source( + "data_frame", + train_data_frame, + val_data_frame, + test_data_frame, + predict_data_frame, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + cat_cols=categorical_fields, + num_cols=numerical_fields, + target_col=target_fields, + mean=mean, + std=std, + codes=codes, + target_codes=target_codes, + classes=classes, + is_regression=cls.is_regression, + **preprocess_kwargs, + ) + + @classmethod + def from_csv( + cls, + categorical_fields: Optional[Union[str, List[str]]], + numerical_fields: Optional[Union[str, List[str]]], + target_fields: Optional[str] = None, + train_file: Optional[str] = None, + val_file: Optional[str] = None, + test_file: Optional[str] = None, + predict_file: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ) -> "DataModule": + """Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files. + + Args: + categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs. + numerical_fields: The field or fields (columns) in the CSV file containing numerical inputs. + target_fields: The field or fields (columns) in the CSV file to use for the target. + train_file: The CSV file containing the training data. + val_file: The CSV file containing the validation data. + test_file: The CSV file containing the testing data. + predict_file: The CSV file containing the data to use when predicting. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = TabularData.from_csv( + "categorical_input", + "numerical_input", + "target", + train_file="train_data.csv", + ) + """ + return cls.from_data_frame( + categorical_fields=categorical_fields, + numerical_fields=numerical_fields, + target_fields=target_fields, + train_data_frame=pd.read_csv(train_file) if train_file is not None else None, + val_data_frame=pd.read_csv(val_file) if val_file is not None else None, + test_data_frame=pd.read_csv(test_file) if test_file is not None else None, + predict_data_frame=pd.read_csv(predict_file) if predict_file is not None else None, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + ) diff --git a/flash/tabular/regression/__init__.py b/flash/tabular/regression/__init__.py new file mode 100644 index 0000000000..a93e599ff0 --- /dev/null +++ b/flash/tabular/regression/__init__.py @@ -0,0 +1 @@ +from flash.tabular.regression.data import TabularRegressionData # noqa: F401 diff --git a/flash/tabular/regression/data.py b/flash/tabular/regression/data.py new file mode 100644 index 0000000000..52bd44cd77 --- /dev/null +++ b/flash/tabular/regression/data.py @@ -0,0 +1,18 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.tabular.data import TabularData + + +class TabularRegressionData(TabularData): + is_regression = True diff --git a/flash/template/classification/backbones.py b/flash/template/classification/backbones.py index b36f6a398e..7ea8413003 100644 --- a/flash/template/classification/backbones.py +++ b/flash/template/classification/backbones.py @@ -21,21 +21,27 @@ @TEMPLATE_BACKBONES(name="mlp-128", namespace="template/classification") def load_mlp_128(num_features, **_): """A simple MLP backbone with 128 hidden units.""" - return nn.Sequential( - nn.Linear(num_features, 128), - nn.ReLU(True), - nn.BatchNorm1d(128), - ), 128 + return ( + nn.Sequential( + nn.Linear(num_features, 128), + nn.ReLU(True), + nn.BatchNorm1d(128), + ), + 128, + ) @TEMPLATE_BACKBONES(name="mlp-128-256", namespace="template/classification") def load_mlp_128_256(num_features, **_): """An two layer MLP backbone with 128 and 256 hidden units respectively.""" - return nn.Sequential( - nn.Linear(num_features, 128), - nn.ReLU(True), - nn.BatchNorm1d(128), - nn.Linear(128, 256), - nn.ReLU(True), - nn.BatchNorm1d(256), - ), 256 + return ( + nn.Sequential( + nn.Linear(num_features, 128), + nn.ReLU(True), + nn.BatchNorm1d(128), + nn.Linear(128, 256), + nn.ReLU(True), + nn.BatchNorm1d(256), + ), + 256, + ) diff --git a/flash/template/classification/data.py b/flash/template/classification/data.py index 2624f1c9f3..f81111bc3c 100644 --- a/flash/template/classification/data.py +++ b/flash/template/classification/data.py @@ -33,8 +33,11 @@ class TemplateNumpyDataSource(NumpyDataSource): - """An example data source that records ``num_features`` on the dataset. We extend - :class:`~flash.core.data.data_source.NumpyDataSource` so that we can use ``super().load_data``.""" + """An example data source that records ``num_features`` on the dataset. + + We extend + :class:`~flash.core.data.data_source.NumpyDataSource` so that we can use ``super().load_data``. + """ def load_data(self, data: Tuple[np.ndarray, Sequence[Any]], dataset: Any) -> Sequence[Mapping[str, Any]]: """Sets the ``num_features`` attribute and calls ``super().load_data``. @@ -109,16 +112,18 @@ def __init__( ) def get_state_dict(self) -> Dict[str, Any]: - """For serialization, you have control over what to save with the ``get_state_dict`` method. It's usually a good - idea to save the transforms. So we just return them here. If you had any other attributes you wanted to save, - this is where you would return them. + """For serialization, you have control over what to save with the ``get_state_dict`` method. + + It's usually a good idea to save the transforms. So we just return them here. If you had any other attributes + you wanted to save, this is where you would return them. """ return self.transforms @classmethod def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): - """This methods gets whatever we returned from ``get_state_dict`` as an input. Now we re-create the class with - the transforms we saved. + """This methods gets whatever we returned from ``get_state_dict`` as an input. + + Now we re-create the class with the transforms we saved. """ return cls(**state_dict) @@ -147,8 +152,10 @@ def default_transforms(self) -> Optional[Dict[str, Callable]]: class TemplateData(DataModule): """Creating our :class:`~flash.core.data.data_module.DataModule` is as easy as setting the ``preprocess_cls`` - attribute. We get the ``from_numpy`` method for free as we've configured a ``DefaultDataSources.NUMPY`` data source. - We'll also add a ``from_sklearn`` method so that we can use our ``TemplateSKLearnDataSource. Finally, we define the + attribute. + + We get the ``from_numpy`` method for free as we've configured a ``DefaultDataSources.NUMPY`` data source. We'll also + add a ``from_sklearn`` method so that we can use our ``TemplateSKLearnDataSource. Finally, we define the ``num_features`` property for convenience. """ @@ -232,13 +239,17 @@ def num_features(self) -> Optional[int]: @staticmethod def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: - """We can, *optionally*, provide a data visualization callback using the ``configure_data_fetcher`` method.""" + """We can, *optionally*, provide a data visualization callback using the ``configure_data_fetcher`` + method.""" return TemplateVisualization(*args, **kwargs) class TemplateVisualization(BaseVisualization): - """The ``TemplateVisualization`` class is a :class:`~flash.core.data.callbacks.BaseVisualization` that just prints - the data. If you want to provide a visualization with your task, you can override these hooks.""" + """The ``TemplateVisualization`` class is a :class:`~flash.core.data.callbacks.BaseVisualization` that just + prints the data. + + If you want to provide a visualization with your task, you can override these hooks. + """ def show_load_sample(self, samples: List[Any], running_stage: RunningStage): print(samples) diff --git a/flash/template/classification/model.py b/flash/template/classification/model.py index e52faf1274..e330fafdc8 100644 --- a/flash/template/classification/model.py +++ b/flash/template/classification/model.py @@ -26,8 +26,8 @@ class TemplateSKLearnClassifier(ClassificationTask): - """The ``TemplateSKLearnClassifier`` is a :class:`~flash.core.classification.ClassificationTask` that classifies - tabular data from scikit-learn. + """The ``TemplateSKLearnClassifier`` is a :class:`~flash.core.classification.ClassificationTask` that + classifies tabular data from scikit-learn. Args: num_features: The number of features (elements) in the input data. @@ -112,9 +112,9 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - """For the predict step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` key from - the input and forward it to the :meth:`~flash.core.model.Task.predict_step`.""" - batch = (batch[DefaultDataKeys.INPUT]) + """For the predict step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` key + from the input and forward it to the :meth:`~flash.core.model.Task.predict_step`.""" + batch = batch[DefaultDataKeys.INPUT] return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) def forward(self, x) -> torch.Tensor: diff --git a/flash/text/__init__.py b/flash/text/__init__.py index 8ac71bdfb5..23786d11f3 100644 --- a/flash/text/__init__.py +++ b/flash/text/__init__.py @@ -1,5 +1,7 @@ from flash.text.classification import TextClassificationData, TextClassifier # noqa: F401 from flash.text.seq2seq import ( # noqa: F401 + QuestionAnsweringData, + QuestionAnsweringTask, Seq2SeqData, Seq2SeqTask, SummarizationData, diff --git a/flash/text/classification/cli.py b/flash/text/classification/cli.py new file mode 100644 index 0000000000..42499bb53f --- /dev/null +++ b/flash/text/classification/cli.py @@ -0,0 +1,81 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from flash.core.data.utils import download_data +from flash.core.utilities.flash_cli import FlashCLI +from flash.text import TextClassificationData, TextClassifier + +__all__ = ["text_classification"] + + +def from_imdb( + backbone: str = "prajjwal1/bert-medium", + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> TextClassificationData: + """Downloads and loads the IMDB sentiment classification data set.""" + download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/") + return TextClassificationData.from_csv( + "review", + "sentiment", + train_file="data/imdb/train.csv", + val_file="data/imdb/valid.csv", + backbone=backbone, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def from_toxic( + backbone: str = "unitary/toxic-bert", + val_split: float = 0.1, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> TextClassificationData: + """Downloads and loads the Jigsaw toxic comments data set.""" + download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "./data") + return TextClassificationData.from_csv( + "comment_text", + ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"], + train_file="data/jigsaw_toxic_comments/train.csv", + backbone=backbone, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def text_classification(): + """Classify text.""" + cli = FlashCLI( + TextClassifier, + TextClassificationData, + default_datamodule_builder=from_imdb, + additional_datamodule_builders=[from_toxic], + default_arguments={ + "trainer.max_epochs": 3, + }, + datamodule_attributes={"num_classes", "multi_label", "backbone"}, + ) + + cli.trainer.save_checkpoint("text_classification_model.pt") + + +if __name__ == "__main__": + text_classification() diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 826ab87e3d..c7b130543d 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -22,7 +22,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataSources, LabelsState from flash.core.data.process import Deserializer, Postprocess, Preprocess -from flash.core.utilities.imports import _requires_extras, _TEXT_AVAILABLE +from flash.core.utilities.imports import _TEXT_AVAILABLE, requires_extras if _TEXT_AVAILABLE: from datasets import DatasetDict, load_dataset @@ -31,9 +31,9 @@ from flash.core.data.data_source import LabelStudioTextDataSource -class TextDeserializer(Deserializer): - @_requires_extras("text") +class TextDeserializer(Deserializer): + @requires_extras("text") def __init__(self, backbone: str, max_length: int, use_fast: bool = True): super().__init__() self.backbone = backbone @@ -58,8 +58,7 @@ def __setstate__(self, state): class TextDataSource(DataSource): - - @_requires_extras("text") + @requires_extras("text") def __init__(self, backbone: str, max_length: int = 128): super().__init__() @@ -93,7 +92,6 @@ def __setstate__(self, state): class TextFileDataSource(TextDataSource): - def __init__(self, filetype: str, backbone: str, max_length: int = 128): super().__init__(backbone, max_length=max_length) @@ -111,7 +109,10 @@ def load_data( dataset: Optional[Any] = None, columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"), ) -> Union[Sequence[Mapping[str, Any]]]: - file, input, target = data + if self.filetype == "json": + file, input, target, field = data + else: + file, input, target = data data_files = {} @@ -121,21 +122,38 @@ def load_data( # FLASH_TESTING is set in the CI to run faster. if flash._IS_TESTING and not torch.cuda.is_available(): try: - dataset_dict = DatasetDict({ - stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] - }) + if self.filetype == "json" and field is not None: + dataset_dict = DatasetDict( + { + stage: load_dataset( + self.filetype, data_files=data_files, split=[f"{stage}[:20]"], field=field + )[0] + } + ) + else: + dataset_dict = DatasetDict( + {stage: load_dataset(self.filetype, data_files=data_files, split=[f"{stage}[:20]"])[0]} + ) except Exception: - dataset_dict = load_dataset(self.filetype, data_files=data_files) + if self.filetype == "json" and field is not None: + dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) + else: + dataset_dict = load_dataset(self.filetype, data_files=data_files) else: - dataset_dict = load_dataset(self.filetype, data_files=data_files) + if self.filetype == "json" and field is not None: + dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) + else: + dataset_dict = load_dataset(self.filetype, data_files=data_files) if not self.predicting: if isinstance(target, List): # multi-target + dataset.multi_label = True dataset_dict = dataset_dict.map(partial(self._multilabel_target, target)) dataset.num_classes = len(target) self.set_state(LabelsState(target)) else: + dataset.multi_label = False if self.training: labels = list(sorted(list(set(dataset_dict[stage][target])))) dataset.num_classes = len(labels) @@ -172,7 +190,6 @@ def __setstate__(self, state): class TextCSVDataSource(TextFileDataSource): - def __init__(self, backbone: str, max_length: int = 128): super().__init__("csv", backbone, max_length=max_length) @@ -187,7 +204,6 @@ def __setstate__(self, state): class TextJSONDataSource(TextFileDataSource): - def __init__(self, backbone: str, max_length: int = 128): super().__init__("json", backbone, max_length=max_length) @@ -202,7 +218,6 @@ def __setstate__(self, state): class TextSentencesDataSource(TextDataSource): - def __init__(self, backbone: str, max_length: int = 128): super().__init__(backbone, max_length=max_length) @@ -214,7 +229,12 @@ def load_data( if isinstance(data, str): data = [data] - return [self._tokenize_fn(s, ) for s in data] + return [ + self._tokenize_fn( + s, + ) + for s in data + ] def __getstate__(self): # TODO: Find out why this is being pickled state = self.__dict__.copy() @@ -227,8 +247,7 @@ def __setstate__(self, state): class TextClassificationPreprocess(Preprocess): - - @_requires_extras("text") + @requires_extras("text") def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -250,7 +269,9 @@ def __init__( DefaultDataSources.CSV: TextCSVDataSource(self.backbone, max_length=max_length), DefaultDataSources.JSON: TextJSONDataSource(self.backbone, max_length=max_length), "sentences": TextSentencesDataSource(self.backbone, max_length=max_length), - DefaultDataSources.LABELSTUDIO: LabelStudioTextDataSource(backbone=self.backbone, max_length=max_length) + DefaultDataSources.LABELSTUDIO: LabelStudioTextDataSource( + backbone=self.backbone, max_length=max_length + ), }, default_data_source="sentences", deserializer=TextDeserializer(backbone, max_length), @@ -275,14 +296,13 @@ def per_batch_transform(self, batch: Any) -> Any: return batch def collate(self, samples: Any) -> Tensor: - """Override to convert a set of samples to a batch""" + """Override to convert a set of samples to a batch.""" if isinstance(samples, dict): samples = [samples] return default_data_collator(samples) class TextClassificationPostprocess(Postprocess): - def per_batch_transform(self, batch: Any) -> Any: if isinstance(batch, SequenceClassifierOutput): batch = batch.logits @@ -290,7 +310,11 @@ def per_batch_transform(self, batch: Any) -> Any: class TextClassificationData(DataModule): - """Data Module for text classification tasks""" + """Data Module for text classification tasks.""" preprocess_cls = TextClassificationPreprocess postprocess_cls = TextClassificationPostprocess + + @property + def backbone(self) -> Optional[str]: + return getattr(self.preprocess, "backbone", None) diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index e1da47be55..cf339153a0 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -16,15 +16,17 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union import torch -from torchmetrics import Accuracy, F1, Metric +from pytorch_lightning import Callback +from torchmetrics import Metric from flash.core.classification import ClassificationTask, Labels from flash.core.data.process import Serializer from flash.core.utilities.imports import _TEXT_AVAILABLE +from flash.text.ort_callback import ORTCallback if _TEXT_AVAILABLE: - from transformers import BertForSequenceClassification - from transformers.modeling_outputs import SequenceClassifierOutput + from transformers import AutoModelForSequenceClassification + from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput class TextClassifier(ClassificationTask): @@ -43,6 +45,7 @@ class TextClassifier(ClassificationTask): learning_rate: Learning rate to use for training, defaults to `1e-3` multi_label: Whether the targets are multi-label or not. serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. + enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training """ required_extras: str = "text" @@ -57,6 +60,7 @@ def __init__( learning_rate: float = 1e-2, multi_label: bool = False, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + enable_ort: bool = False, ): self.save_hyperparameters() @@ -67,49 +71,53 @@ def __init__( os.environ["PYTHONWARNINGS"] = "ignore" super().__init__( + num_classes=num_classes, model=None, loss_fn=loss_fn, optimizer=optimizer, - metrics=metrics or (F1(num_classes) if multi_label else Accuracy()), + metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, serializer=serializer or Labels(multi_label=multi_label), ) - self.model = BertForSequenceClassification.from_pretrained(backbone, num_labels=num_classes) - + self.enable_ort = enable_ort + self.model = AutoModelForSequenceClassification.from_pretrained(backbone, num_labels=num_classes) self.save_hyperparameters() @property def backbone(self): - # see huggingface's BertForSequenceClassification - return self.model.bert + return self.model.base_model def forward(self, batch: Dict[str, torch.Tensor]): return self.model(input_ids=batch.get("input_ids", None), attention_mask=batch.get("attention_mask", None)) def to_loss_format(self, x) -> torch.Tensor: - if isinstance(x, SequenceClassifierOutput): + if isinstance(x, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)): x = x.logits return super().to_loss_format(x) def to_metrics_format(self, x) -> torch.Tensor: - if isinstance(x, SequenceClassifierOutput): + if isinstance(x, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)): x = x.logits return super().to_metrics_format(x) - def step(self, batch, batch_idx) -> dict: + def step(self, batch, batch_idx, metrics) -> dict: target = batch.pop("labels") batch = (batch, target) - return super().step(batch, batch_idx) + return super().step(batch, batch_idx, metrics) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: return self(batch) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]): - """ - This function is used only for debugging usage with CI - """ + """This function is used only for debugging usage with CI.""" if self.hparams.multi_label: - assert history[-1]["val_f1"] > 0.45 + assert history[-1]["val_f1"] > 0.40, history[-1]["val_f1"] else: - assert history[-1]["val_accuracy"] > 0.73 + assert history[-1]["val_accuracy"] > 0.70, history[-1]["val_accuracy"] + + def configure_callbacks(self) -> List[Callback]: + callbacks = super().configure_callbacks() or [] + if self.enable_ort: + callbacks.append(ORTCallback()) + return callbacks diff --git a/flash/text/ort_callback.py b/flash/text/ort_callback.py new file mode 100644 index 0000000000..53b5bdf197 --- /dev/null +++ b/flash/text/ort_callback.py @@ -0,0 +1,51 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning import Callback, LightningModule, Trainer +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE + +if _TORCH_ORT_AVAILABLE: + from torch_ort import ORTModule + + +class ORTCallback(Callback): + """Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime. + + Wraps a model with the ORT wrapper, lazily converting your module into an ONNX export, to optimize for + training and inference. + + Usage: + + # via Transformer Tasks + model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True) + + # or via the trainer + trainer = flash.Trainer(callbacks=ORTCallback()) + """ + + def __init__(self): + if not _TORCH_ORT_AVAILABLE: + raise MisconfigurationException( + "Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort" + ) + + def on_before_accelerator_backend_setup(self, trainer: Trainer, pl_module: LightningModule) -> None: + if not hasattr(pl_module, "model"): + raise MisconfigurationException( + "Torch ORT requires to wrap a single model that defines a forward function " + "assigned as `model` inside the `LightningModule`." + ) + if not isinstance(pl_module.model, ORTModule): + pl_module.model = ORTModule(pl_module.model) diff --git a/flash/text/seq2seq/__init__.py b/flash/text/seq2seq/__init__.py index 1c30bc9d85..88adc2ab65 100644 --- a/flash/text/seq2seq/__init__.py +++ b/flash/text/seq2seq/__init__.py @@ -1,3 +1,4 @@ from flash.text.seq2seq.core import Seq2SeqData, Seq2SeqFreezeEmbeddings, Seq2SeqTask # noqa: F401 +from flash.text.seq2seq.question_answering import QuestionAnsweringData, QuestionAnsweringTask # noqa: F401 from flash.text.seq2seq.summarization import SummarizationData, SummarizationTask # noqa: F401 from flash.text.seq2seq.translation import TranslationData, TranslationTask # noqa: F401 diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index 1b29d7e2c2..60404a5b66 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -23,7 +23,7 @@ from flash.core.data.data_source import DataSource, DefaultDataSources from flash.core.data.process import Postprocess, Preprocess from flash.core.data.properties import ProcessState -from flash.core.utilities.imports import _requires_extras, _TEXT_AVAILABLE +from flash.core.utilities.imports import _TEXT_AVAILABLE, requires_extras from flash.text.classification.data import TextDeserializer if _TEXT_AVAILABLE: @@ -33,14 +33,13 @@ class Seq2SeqDataSource(DataSource): - - @_requires_extras("text") + @requires_extras("text") def __init__( self, backbone: str, max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length' + padding: Union[str, bool] = "max_length", ): super().__init__() @@ -82,23 +81,25 @@ def __setstate__(self, state): class Seq2SeqFileDataSource(Seq2SeqDataSource): - def __init__( self, filetype: str, backbone: str, max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length', + padding: Union[str, bool] = "max_length", ): super().__init__(backbone, max_source_length, max_target_length, padding) self.filetype = filetype - def load_data(self, data: Any, columns: List[str] = None) -> 'datasets.Dataset': + def load_data(self, data: Any, columns: List[str] = None) -> "datasets.Dataset": if columns is None: columns = ["input_ids", "attention_mask", "labels"] - file, input, target = data + if self.filetype == "json": + file, input, target, field = data + else: + file, input, target = data data_files = {} stage = self._running_stage.value data_files[stage] = str(file) @@ -106,19 +107,34 @@ def load_data(self, data: Any, columns: List[str] = None) -> 'datasets.Dataset': # FLASH_TESTING is set in the CI to run faster. if flash._IS_TESTING: try: - dataset_dict = DatasetDict({ - stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] - }) + if self.filetype == "json" and field is not None: + dataset_dict = DatasetDict( + { + stage: load_dataset( + self.filetype, data_files=data_files, split=[f"{stage}[:20]"], field=field + )[0] + } + ) + else: + dataset_dict = DatasetDict( + {stage: load_dataset(self.filetype, data_files=data_files, split=[f"{stage}[:20]"])[0]} + ) except Exception: - dataset_dict = load_dataset(self.filetype, data_files=data_files) + if self.filetype == "json" and field is not None: + dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) + else: + dataset_dict = load_dataset(self.filetype, data_files=data_files) else: - dataset_dict = load_dataset(self.filetype, data_files=data_files) + if self.filetype == "json" and field is not None: + dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) + else: + dataset_dict = load_dataset(self.filetype, data_files=data_files) dataset_dict = dataset_dict.map(partial(self._tokenize_fn, input=input, target=target), batched=True) dataset_dict.set_format(columns=columns) return dataset_dict[stage] - def predict_load_data(self, data: Any) -> Union['datasets.Dataset', List[Dict[str, torch.Tensor]]]: + def predict_load_data(self, data: Any) -> Union["datasets.Dataset", List[Dict[str, torch.Tensor]]]: return self.load_data(data, columns=["input_ids", "attention_mask"]) def __getstate__(self): # TODO: Find out why this is being pickled @@ -132,13 +148,12 @@ def __setstate__(self, state): class Seq2SeqCSVDataSource(Seq2SeqFileDataSource): - def __init__( self, backbone: str, max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length', + padding: Union[str, bool] = "max_length", ): super().__init__( "csv", @@ -159,13 +174,12 @@ def __setstate__(self, state): class Seq2SeqJSONDataSource(Seq2SeqFileDataSource): - def __init__( self, backbone: str, max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length', + padding: Union[str, bool] = "max_length", ): super().__init__( "json", @@ -186,7 +200,6 @@ def __setstate__(self, state): class Seq2SeqSentencesDataSource(Seq2SeqDataSource): - def load_data( self, data: Union[str, List[str]], @@ -217,8 +230,7 @@ class Seq2SeqBackboneState(ProcessState): class Seq2SeqPreprocess(Preprocess): - - @_requires_extras("text") + @requires_extras("text") def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -228,7 +240,7 @@ def __init__( backbone: str = "sshleifer/tiny-mbart", max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length' + padding: Union[str, bool] = "max_length", ): self.backbone = backbone self.max_target_length = max_target_length @@ -261,7 +273,7 @@ def __init__( ), }, default_data_source="sentences", - deserializer=TextDeserializer(backbone, max_source_length) + deserializer=TextDeserializer(backbone, max_source_length), ) self.set_state(Seq2SeqBackboneState(self.backbone)) @@ -280,13 +292,12 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): return cls(**state_dict) def collate(self, samples: Any) -> Tensor: - """Override to convert a set of samples to a batch""" + """Override to convert a set of samples to a batch.""" return default_data_collator(samples) class Seq2SeqPostprocess(Postprocess): - - @_requires_extras("text") + @requires_extras("text") def __init__(self): super().__init__() diff --git a/flash/text/seq2seq/core/finetuning.py b/flash/text/seq2seq/core/finetuning.py index 6d3ea3e512..f75ab65a54 100644 --- a/flash/text/seq2seq/core/finetuning.py +++ b/flash/text/seq2seq/core/finetuning.py @@ -17,9 +17,7 @@ class Seq2SeqFreezeEmbeddings(FlashBaseFinetuning): - """ - Freezes the embedding layers during Seq2Seq training. - """ + """Freezes the embedding layers during Seq2Seq training.""" def __init__(self, model_type: str, train_bn: bool = True): super().__init__("", train_bn) diff --git a/flash/text/seq2seq/summarization/metric.py b/flash/text/seq2seq/core/metrics.py similarity index 51% rename from flash/text/seq2seq/summarization/metric.py rename to flash/text/seq2seq/core/metrics.py index 1e7e7dd3f0..a99c113122 100644 --- a/flash/text/seq2seq/summarization/metric.py +++ b/flash/text/seq2seq/core/metrics.py @@ -11,14 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# referenced from +# Library Name: torchtext +# Authors: torchtext authors and @sluks +# Date: 2020-07-18 +# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score +from collections import Counter from typing import Dict, List, Tuple import numpy as np +import torch from torch import tensor from torchmetrics import Metric -from flash.core.utilities.imports import _requires_extras, _TEXT_AVAILABLE -from flash.text.seq2seq.summarization.utils import add_newline_to_end_of_each_sentence +from flash.core.utilities.imports import _TEXT_AVAILABLE, requires_extras +from flash.text.seq2seq.core.utils import add_newline_to_end_of_each_sentence if _TEXT_AVAILABLE: from rouge_score import rouge_scorer @@ -27,9 +34,103 @@ AggregateScore, Score, BootstrapAggregator = None, None, object -class RougeMetric(Metric): +def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: + """ + Counting how many times each word appears in a given text with ngram + Args: + ngram_input_list: A list of translated text or reference texts + n_gram: gram value ranged 1 to 4 + + Return: + ngram_counter: a collections.Counter object of ngram """ - Metric used for automatic summarization. https://www.aclweb.org/anthology/W04-1013/ + + ngram_counter = Counter() + + for i in range(1, n_gram + 1): + for j in range(len(ngram_input_list) - i + 1): + ngram_key = tuple(ngram_input_list[j : (i + j)]) + ngram_counter[ngram_key] += 1 + + return ngram_counter + + +class BLEUScore(Metric): + """Calculate BLEU score of machine translated text with one or more references. + + Example: + >>> translate_corpus = ['the cat is on the mat'.split()] + >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] + >>> metric = BLEUScore() + >>> metric(translate_corpus, reference_corpus) + tensor(0.7598) + """ + + def __init__(self, n_gram: int = 4, smooth: bool = False): + """ + Args: + n_gram: Gram value ranged from 1 to 4 (Default 4) + smooth: Whether or not to apply smoothing – Lin et al. 2004 + """ + super().__init__() + self.n_gram = n_gram + self.smooth = smooth + + self.add_state("c", tensor(0, dtype=torch.float), dist_reduce_fx="sum") + self.add_state("r", tensor(0, dtype=torch.float), dist_reduce_fx="sum") + self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum") + self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum") + + def compute(self): + + trans_len = self.c.clone().detach() + ref_len = self.r.clone().detach() + + if min(self.numerator) == 0.0: + return tensor(0.0, device=self.r.device) + + if self.smooth: + precision_scores = (self.numerator + 1.0) / (self.denominator + 1.0) + else: + precision_scores = self.numerator / self.denominator + + log_precision_scores = tensor([1.0 / self.n_gram] * self.n_gram, device=self.r.device) * torch.log( + precision_scores + ) + geometric_mean = torch.exp(torch.sum(log_precision_scores)) + brevity_penalty = tensor(1.0, device=self.r.device) if self.c > self.r else torch.exp(1 - (ref_len / trans_len)) + bleu = brevity_penalty * geometric_mean + return bleu + + def update(self, translate_corpus, reference_corpus) -> None: + """ + Actual metric computation + Args: + translate_corpus: An iterable of machine translated corpus + reference_corpus: An iterable of iterables of reference corpus + """ + for (translation, references) in zip(translate_corpus, reference_corpus): + self.c += len(translation) + ref_len_list = [len(ref) for ref in references] + ref_len_diff = [abs(len(translation) - x) for x in ref_len_list] + self.r += ref_len_list[ref_len_diff.index(min(ref_len_diff))] + translation_counter = _count_ngram(translation, self.n_gram) + reference_counter = Counter() + + for ref in references: + reference_counter |= _count_ngram(ref, self.n_gram) + + ngram_counter_clip = translation_counter & reference_counter + + for counter_clip in ngram_counter_clip: + self.numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip] + + for counter in translation_counter: + self.denominator[len(counter) - 1] += translation_counter[counter] + + +class RougeMetric(Metric): + """Metric used for automatic summarization. https://www.aclweb.org/anthology/W04-1013/ Example: @@ -52,7 +153,7 @@ class RougeMetric(Metric): 'rougeLsum_recall': 0.25} """ - @_requires_extras("text") + @requires_extras("text") def __init__( self, rouge_newline_sep: bool = False, @@ -102,13 +203,11 @@ def __hash__(self): class RougeBatchAggregator(BootstrapAggregator): - """ - Aggregates rouge scores and provides confidence intervals. - """ + """Aggregates rouge scores and provides confidence intervals.""" def aggregate(self): - """ - Override function to wrap the final results in `Score` objects. + """Override function to wrap the final results in `Score` objects. + This is due to the scores being replaced with a list of torch tensors. """ result = {} @@ -118,7 +217,7 @@ def aggregate(self): # Percentiles are returned as (interval, measure). percentiles = self._bootstrap_resample(score_matrix) # Extract the three intervals (low, mid, high). - intervals = tuple((Score(*percentiles[j, :]) for j in range(3))) + intervals = tuple(Score(*percentiles[j, :]) for j in range(3)) result[score_type] = AggregateScore(low=intervals[0], mid=intervals[1], high=intervals[2]) return result diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index d965c084ae..d79ca18a78 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -16,6 +16,7 @@ from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union import torch +from pytorch_lightning import Callback from pytorch_lightning.utilities import rank_zero_info from torch import Tensor from torchmetrics import Metric @@ -23,6 +24,7 @@ from flash.core.finetuning import FlashBaseFinetuning from flash.core.model import Task from flash.core.utilities.imports import _TEXT_AVAILABLE +from flash.text.ort_callback import ORTCallback from flash.text.seq2seq.core.finetuning import Seq2SeqFreezeEmbeddings if _TEXT_AVAILABLE: @@ -40,7 +42,7 @@ def _pad_tensors_to_max_len(model_cfg, tensor, max_length): ) padded_tensor = pad_token_id * torch.ones((tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device) - padded_tensor[:, :tensor.shape[-1]] = tensor + padded_tensor[:, : tensor.shape[-1]] = tensor return padded_tensor @@ -54,19 +56,21 @@ class Seq2SeqTask(Task): learning_rate: Learning rate to use for training, defaults to `3e-4` val_target_max_length: Maximum length of targets in validation. Defaults to `128` num_beams: Number of beams to use in validation when generating predictions. Defaults to `4` + enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training """ required_extras: str = "text" def __init__( self, - backbone: str = 't5-small', + backbone: str = "t5-small", loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, learning_rate: float = 5e-5, val_target_max_length: Optional[int] = None, num_beams: Optional[int] = None, + enable_ort: bool = False, ): os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" # disable HF thousand warnings @@ -75,6 +79,7 @@ def __init__( os.environ["PYTHONWARNINGS"] = "ignore" super().__init__(loss_fn=loss_fn, optimizer=optimizer, metrics=metrics, learning_rate=learning_rate) self.model = AutoModelForSeq2SeqLM.from_pretrained(backbone) + self.enable_ort = enable_ort self.val_target_max_length = val_target_max_length self.num_beams = num_beams self._initialize_model_specific_parameters() @@ -83,7 +88,7 @@ def forward(self, x: Any) -> Any: max_length = self.val_target_max_length if self.val_target_max_length else self.model.config.max_length num_beams = self.num_beams if self.num_beams else self.model.config.num_beams generated_tokens = self.model.generate( - input_ids=x['input_ids'], attention_mask=x['attention_mask'], max_length=max_length, num_beams=num_beams + input_ids=x["input_ids"], attention_mask=x["attention_mask"], max_length=max_length, num_beams=num_beams ) # in case the batch is shorter than max length, the output should be padded if generated_tokens.shape[-1] < max_length: @@ -113,9 +118,7 @@ def compute_metrics(self, generated_tokens, batch, prefix): @property def task(self) -> Optional[str]: - """ - Override to define AutoConfig task specific parameters stored within the model. - """ + """Override to define AutoConfig task specific parameters stored within the model.""" return def _initialize_model_specific_parameters(self): @@ -127,7 +130,7 @@ def _initialize_model_specific_parameters(self): self.model.config.update(pars) @property - def tokenizer(self) -> 'PreTrainedTokenizerBase': + def tokenizer(self) -> "PreTrainedTokenizerBase": return self.data_pipeline.data_source.tokenizer def tokenize_labels(self, labels: Tensor) -> List[str]: @@ -136,3 +139,9 @@ def tokenize_labels(self, labels: Tensor) -> List[str]: def configure_finetune_callback(self) -> List[FlashBaseFinetuning]: return [Seq2SeqFreezeEmbeddings(self.model.config.model_type, train_bn=True)] + + def configure_callbacks(self) -> List[Callback]: + callbacks = super().configure_callbacks() or [] + if self.enable_ort: + callbacks.append(ORTCallback()) + return callbacks diff --git a/flash/text/seq2seq/summarization/utils.py b/flash/text/seq2seq/core/utils.py similarity index 97% rename from flash/text/seq2seq/summarization/utils.py rename to flash/text/seq2seq/core/utils.py index 02647f7264..e48248754c 100644 --- a/flash/text/seq2seq/summarization/utils.py +++ b/flash/text/seq2seq/core/utils.py @@ -16,8 +16,9 @@ from pytorch_lightning.utilities import _module_available nltk = None -if _module_available('nltk'): +if _module_available("nltk"): import nltk + nltk.download("punkt", quiet=True) diff --git a/flash/text/seq2seq/question_answering/__init__.py b/flash/text/seq2seq/question_answering/__init__.py new file mode 100644 index 0000000000..83330ccb4b --- /dev/null +++ b/flash/text/seq2seq/question_answering/__init__.py @@ -0,0 +1,2 @@ +from flash.text.seq2seq.question_answering.data import QuestionAnsweringData # noqa: F401 +from flash.text.seq2seq.question_answering.model import QuestionAnsweringTask # noqa: F401 diff --git a/flash/text/seq2seq/question_answering/data.py b/flash/text/seq2seq/question_answering/data.py new file mode 100644 index 0000000000..ad3f028f20 --- /dev/null +++ b/flash/text/seq2seq/question_answering/data.py @@ -0,0 +1,46 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Dict, Optional, Union + +from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPostprocess, Seq2SeqPreprocess + + +class QuestionAnsweringPreprocess(Seq2SeqPreprocess): + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + backbone: str = "t5-small", + max_source_length: int = 128, + max_target_length: int = 128, + padding: Union[str, bool] = "max_length", + ): + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + backbone=backbone, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, + ) + + +class QuestionAnsweringData(Seq2SeqData): + + preprocess_cls = QuestionAnsweringPreprocess + postprocess_cls = Seq2SeqPostprocess diff --git a/flash/text/seq2seq/question_answering/model.py b/flash/text/seq2seq/question_answering/model.py new file mode 100644 index 0000000000..0ebec8aed3 --- /dev/null +++ b/flash/text/seq2seq/question_answering/model.py @@ -0,0 +1,85 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union + +import torch +from torchmetrics import Metric + +from flash.text.seq2seq.core.metrics import RougeMetric +from flash.text.seq2seq.core.model import Seq2SeqTask + + +class QuestionAnsweringTask(Seq2SeqTask): + """The ``QuestionAnsweringTask`` is a :class:`~flash.Task` for Seq2Seq text question answering. For more + details, see `question_answering`. + + You can change the backbone to any question answering model from `HuggingFace/transformers + `_ using the ``backbone`` argument. + + .. note:: When changing the backbone, make sure you pass in the same backbone to the :class:`~flash.Task` and the + :class:`~flash.core.data.data_module.DataModule` object! Since this is a Seq2Seq task, make sure you use a + Seq2Seq model. + + Args: + backbone: backbone model to use for the task. + loss_fn: Loss function for training. + optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. + metrics: Metrics to compute for training and evaluation. Defauls to calculating the ROUGE metric. + Changing this argument currently has no effect. + learning_rate: Learning rate to use for training, defaults to `3e-4` + val_target_max_length: Maximum length of targets in validation. Defaults to `128` + num_beams: Number of beams to use in validation when generating predictions. Defaults to `4` + use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching. + rouge_newline_sep: Add a new line at the beginning of each sentence in Rouge Metric calculation. + enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training + """ + + def __init__( + self, + backbone: str = "t5-small", + loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, + optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, + metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, + learning_rate: float = 1e-5, + val_target_max_length: Optional[int] = None, + num_beams: Optional[int] = 4, + use_stemmer: bool = True, + rouge_newline_sep: bool = True, + enable_ort: bool = False, + ): + self.save_hyperparameters() + super().__init__( + backbone=backbone, + loss_fn=loss_fn, + optimizer=optimizer, + metrics=metrics, + learning_rate=learning_rate, + val_target_max_length=val_target_max_length, + num_beams=num_beams, + enable_ort=enable_ort, + ) + self.rouge = RougeMetric( + rouge_newline_sep=rouge_newline_sep, + use_stemmer=use_stemmer, + ) + + def compute_metrics(self, generated_tokens: torch.Tensor, batch: Dict, prefix: str) -> None: + tgt_lns = self.tokenize_labels(batch["labels"]) + result = self.rouge(self._postprocess.uncollate(generated_tokens), tgt_lns) + self.log_dict(result, on_step=False, on_epoch=True, prog_bar=True) + + @staticmethod + def _ci_benchmark_fn(history: List[Dict[str, Any]]): + """This function is used only for debugging usage with CI.""" + assert history[-1]["rouge1_recall"] > 0.2 diff --git a/flash/text/seq2seq/summarization/cli.py b/flash/text/seq2seq/summarization/cli.py new file mode 100644 index 0000000000..666dd87f40 --- /dev/null +++ b/flash/text/seq2seq/summarization/cli.py @@ -0,0 +1,59 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from flash.core.data.utils import download_data +from flash.core.utilities.flash_cli import FlashCLI +from flash.text import SummarizationData, SummarizationTask + +__all__ = ["summarization"] + + +def from_xsum( + backbone: str = "sshleifer/distilbart-xsum-1-1", + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> SummarizationData: + """Downloads and loads the XSum data set.""" + download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "./data/") + return SummarizationData.from_csv( + "input", + "target", + train_file="data/xsum/train.csv", + val_file="data/xsum/valid.csv", + backbone=backbone, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def summarization(): + """Summarize text.""" + cli = FlashCLI( + SummarizationTask, + SummarizationData, + default_datamodule_builder=from_xsum, + default_arguments={ + "trainer.max_epochs": 3, + "model.backbone": "sshleifer/distilbart-xsum-1-1", + }, + ) + + cli.trainer.save_checkpoint("summarization_model_xsum.pt") + + +if __name__ == "__main__": + summarization() diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py index c2a29df52c..3797d97f92 100644 --- a/flash/text/seq2seq/summarization/data.py +++ b/flash/text/seq2seq/summarization/data.py @@ -17,7 +17,6 @@ class SummarizationPreprocess(Seq2SeqPreprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -27,7 +26,7 @@ def __init__( backbone: str = "sshleifer/distilbart-xsum-1-1", max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length' + padding: Union[str, bool] = "max_length", ): super().__init__( train_transform=train_transform, diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index d547972f3f..19e812baf1 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -16,8 +16,8 @@ import torch from torchmetrics import Metric +from flash.text.seq2seq.core.metrics import RougeMetric from flash.text.seq2seq.core.model import Seq2SeqTask -from flash.text.seq2seq.summarization.metric import RougeMetric class SummarizationTask(Seq2SeqTask): @@ -42,6 +42,7 @@ class SummarizationTask(Seq2SeqTask): num_beams: Number of beams to use in validation when generating predictions. Defaults to `4` use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching. rouge_newline_sep: Add a new line at the beginning of each sentence in Rouge Metric calculation. + enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training """ def __init__( @@ -54,7 +55,8 @@ def __init__( val_target_max_length: Optional[int] = None, num_beams: Optional[int] = 4, use_stemmer: bool = True, - rouge_newline_sep: bool = True + rouge_newline_sep: bool = True, + enable_ort: bool = False, ): self.save_hyperparameters() super().__init__( @@ -64,7 +66,8 @@ def __init__( metrics=metrics, learning_rate=learning_rate, val_target_max_length=val_target_max_length, - num_beams=num_beams + num_beams=num_beams, + enable_ort=enable_ort, ) self.rouge = RougeMetric( rouge_newline_sep=rouge_newline_sep, @@ -82,7 +85,5 @@ def compute_metrics(self, generated_tokens: torch.Tensor, batch: Dict, prefix: s @staticmethod def _ci_benchmark_fn(history: List[Dict[str, Any]]): - """ - This function is used only for debugging usage with CI - """ + """This function is used only for debugging usage with CI.""" assert history[-1]["rouge1_recall"] > 0.2 diff --git a/flash/text/seq2seq/translation/cli.py b/flash/text/seq2seq/translation/cli.py new file mode 100644 index 0000000000..1609cb4de0 --- /dev/null +++ b/flash/text/seq2seq/translation/cli.py @@ -0,0 +1,59 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from flash.core.data.utils import download_data +from flash.core.utilities.flash_cli import FlashCLI +from flash.text import TranslationData, TranslationTask + +__all__ = ["translation"] + + +def from_wmt_en_ro( + backbone: str = "Helsinki-NLP/opus-mt-en-ro", + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> TranslationData: + """Downloads and loads the WMT EN RO data set.""" + download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "./data") + return TranslationData.from_csv( + "input", + "target", + train_file="data/wmt_en_ro/train.csv", + val_file="data/wmt_en_ro/valid.csv", + backbone=backbone, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def translation(): + """Translate text.""" + cli = FlashCLI( + TranslationTask, + TranslationData, + default_datamodule_builder=from_wmt_en_ro, + default_arguments={ + "trainer.max_epochs": 3, + "model.backbone": "Helsinki-NLP/opus-mt-en-ro", + }, + ) + + cli.trainer.save_checkpoint("translation_model_en_ro.pt") + + +if __name__ == "__main__": + translation() diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index 0b9e7a3ce7..5485be1003 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -17,7 +17,6 @@ class TranslationPreprocess(Seq2SeqPreprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -27,7 +26,7 @@ def __init__( backbone: str = "t5-small", max_source_length: int = 128, max_target_length: int = 128, - padding: Union[str, bool] = 'max_length' + padding: Union[str, bool] = "max_length", ): super().__init__( train_transform=train_transform, diff --git a/flash/text/seq2seq/translation/metric.py b/flash/text/seq2seq/translation/metric.py deleted file mode 100644 index bd3e4fe872..0000000000 --- a/flash/text/seq2seq/translation/metric.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# referenced from -# Library Name: torchtext -# Authors: torchtext authors and @sluks -# Date: 2020-07-18 -# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from collections import Counter -from typing import List - -import torch -from torch import tensor -from torchmetrics import Metric - - -def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: - """ - Counting how many times each word appears in a given text with ngram - Args: - ngram_input_list: A list of translated text or reference texts - n_gram: gram value ranged 1 to 4 - - Return: - ngram_counter: a collections.Counter object of ngram - """ - - ngram_counter = Counter() - - for i in range(1, n_gram + 1): - for j in range(len(ngram_input_list) - i + 1): - ngram_key = tuple(ngram_input_list[j:(i + j)]) - ngram_counter[ngram_key] += 1 - - return ngram_counter - - -class BLEUScore(Metric): - """ - Calculate BLEU score of machine translated text with one or more references. - - Example: - >>> translate_corpus = ['the cat is on the mat'.split()] - >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] - >>> metric = BLEUScore() - >>> metric(translate_corpus, reference_corpus) - tensor(0.7598) - """ - - def __init__(self, n_gram: int = 4, smooth: bool = False): - """ - Args: - n_gram: Gram value ranged from 1 to 4 (Default 4) - smooth: Whether or not to apply smoothing – Lin et al. 2004 - """ - super().__init__() - self.n_gram = n_gram - self.smooth = smooth - - self.add_state("c", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - self.add_state("r", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum") - self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum") - - def compute(self): - - trans_len = self.c.clone().detach() - ref_len = self.r.clone().detach() - - if min(self.numerator) == 0.0: - return tensor(0.0, device=self.r.device) - - if self.smooth: - precision_scores = (self.numerator + 1.0) / (self.denominator + 1.0) - else: - precision_scores = self.numerator / self.denominator - - log_precision_scores = tensor([1.0 / self.n_gram] * self.n_gram, - device=self.r.device) * torch.log(precision_scores) - geometric_mean = torch.exp(torch.sum(log_precision_scores)) - brevity_penalty = ( - tensor(1.0, device=self.r.device) if self.c > self.r else torch.exp(1 - (ref_len / trans_len)) - ) - bleu = brevity_penalty * geometric_mean - return bleu - - def update(self, translate_corpus, reference_corpus) -> None: - """ - Actual metric computation - Args: - translate_corpus: An iterable of machine translated corpus - reference_corpus: An iterable of iterables of reference corpus - """ - for (translation, references) in zip(translate_corpus, reference_corpus): - self.c += len(translation) - ref_len_list = [len(ref) for ref in references] - ref_len_diff = [abs(len(translation) - x) for x in ref_len_list] - self.r += ref_len_list[ref_len_diff.index(min(ref_len_diff))] - translation_counter = _count_ngram(translation, self.n_gram) - reference_counter = Counter() - - for ref in references: - reference_counter |= _count_ngram(ref, self.n_gram) - - ngram_counter_clip = translation_counter & reference_counter - - for counter_clip in ngram_counter_clip: - self.numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip] - - for counter in translation_counter: - self.denominator[len(counter) - 1] += translation_counter[counter] diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index a9ac0a6a31..c70089e8d6 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -16,8 +16,8 @@ import torch from torchmetrics import Metric +from flash.text.seq2seq.core.metrics import BLEUScore from flash.text.seq2seq.core.model import Seq2SeqTask -from flash.text.seq2seq.translation.metric import BLEUScore class TranslationTask(Seq2SeqTask): @@ -42,6 +42,7 @@ class TranslationTask(Seq2SeqTask): num_beams: Number of beams to use in validation when generating predictions. Defaults to `4` n_gram: Maximum n_grams to use in metric calculation. Defaults to `4` smooth: Apply smoothing in BLEU calculation. Defaults to `True` + enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training """ def __init__( @@ -55,6 +56,7 @@ def __init__( num_beams: Optional[int] = 4, n_gram: bool = 4, smooth: bool = True, + enable_ort: bool = False, ): self.save_hyperparameters() super().__init__( @@ -65,6 +67,7 @@ def __init__( learning_rate=learning_rate, val_target_max_length=val_target_max_length, num_beams=num_beams, + enable_ort=enable_ort, ) self.bleu = BLEUScore( n_gram=n_gram, @@ -84,7 +87,5 @@ def compute_metrics(self, generated_tokens, batch, prefix): @staticmethod def _ci_benchmark_fn(history: List[Dict[str, Any]]): - """ - This function is used only for debugging usage with CI - """ + """This function is used only for debugging usage with CI.""" assert history[-1]["val_bleu_score"] > 0.6 diff --git a/flash/video/classification/cli.py b/flash/video/classification/cli.py new file mode 100644 index 0000000000..840386506b --- /dev/null +++ b/flash/video/classification/cli.py @@ -0,0 +1,61 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Optional + +from flash.core.data.utils import download_data +from flash.core.utilities.flash_cli import FlashCLI +from flash.video import VideoClassificationData, VideoClassifier + +__all__ = ["video_classification"] + + +def from_kinetics( + clip_sampler: str = "uniform", + clip_duration: int = 1, + decode_audio: bool = False, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs, +) -> VideoClassificationData: + """Downloads and loads the Kinetics data set.""" + download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip", "./data") + return VideoClassificationData.from_folders( + train_folder=os.path.join(os.getcwd(), "data/kinetics/train"), + val_folder=os.path.join(os.getcwd(), "data/kinetics/val"), + clip_sampler=clip_sampler, + clip_duration=clip_duration, + decode_audio=decode_audio, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +def video_classification(): + """Classify videos.""" + cli = FlashCLI( + VideoClassifier, + VideoClassificationData, + default_datamodule_builder=from_kinetics, + default_arguments={ + "trainer.max_epochs": 3, + }, + ) + + cli.trainer.save_checkpoint("video_classification.pt") + + +if __name__ == "__main__": + video_classification() diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index db9cddc6f8..a57de670d7 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -25,8 +25,8 @@ DefaultDataSources, FiftyOneDataSource, LabelsState, + LabelStudioVideoDataSource, PathsDataSource, - LabelStudioVideoDataSource ) from flash.core.data.process import Preprocess from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _KORNIA_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, lazy_import @@ -45,21 +45,20 @@ if _PYTORCHVIDEO_AVAILABLE: from pytorchvideo.data.clip_sampling import ClipSampler, make_clip_sampler from pytorchvideo.data.encoded_video import EncodedVideo - from pytorchvideo.data.encoded_video_dataset import EncodedVideoDataset, labeled_encoded_video_dataset + from pytorchvideo.data.labeled_video_dataset import labeled_video_dataset, LabeledVideoDataset from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths from pytorchvideo.transforms import ApplyTransformToKey, UniformTemporalSubsample from torchvision.transforms import CenterCrop, Compose, RandomCrop, RandomHorizontalFlip else: - ClipSampler, EncodedVideoDataset, EncodedVideo, ApplyTransformToKey = None, None, None, None + ClipSampler, LabeledVideoDataset, EncodedVideo, ApplyTransformToKey = None, None, None, None _PYTORCHVIDEO_DATA = Dict[str, Union[str, torch.Tensor, int, float, List]] -class BaseVideoClassification(object): - +class BaseVideoClassification: def __init__( self, - clip_sampler: 'ClipSampler', + clip_sampler: "ClipSampler", video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, decode_audio: bool = True, decoder: str = "pyav", @@ -69,28 +68,31 @@ def __init__( self.decode_audio = decode_audio self.decoder = decoder - def load_data(self, data: str, dataset: Optional[Any] = None) -> 'EncodedVideoDataset': + def load_data(self, data: str, dataset: Optional[Any] = None) -> "LabeledVideoDataset": ds = self._make_encoded_video_dataset(data) if self.training: label_to_class_mapping = {p[1]: p[0].split("/")[-2] for p in ds._labeled_videos._paths_and_labels} self.set_state(LabelsState(label_to_class_mapping)) - dataset.num_classes = len(np.unique([s[1]['label'] for s in ds._labeled_videos])) + dataset.num_classes = len(np.unique([s[1]["label"] for s in ds._labeled_videos])) return ds + def load_sample(self, sample): + return sample + def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: video_path = sample[DefaultDataKeys.INPUT] sample.update(self._encoded_video_to_dict(EncodedVideo.from_path(video_path))) sample[DefaultDataKeys.METADATA] = {"filepath": video_path} return sample - def _encoded_video_to_dict(self, video) -> Dict[str, Any]: + def _encoded_video_to_dict(self, video, annotation: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: ( clip_start, clip_end, clip_index, aug_index, is_last_clip, - ) = self.clip_sampler(0.0, video.duration) + ) = self.clip_sampler(0.0, video.duration, annotation) loaded_clip = video.get_clip(clip_start, clip_end) @@ -111,20 +113,17 @@ def _encoded_video_to_dict(self, video) -> Dict[str, Any]: "video_index": 0, "clip_index": clip_index, "aug_index": aug_index, - **({ - "audio": audio_samples - } if audio_samples is not None else {}), + **({"audio": audio_samples} if audio_samples is not None else {}), } - def _make_encoded_video_dataset(self, data) -> 'EncodedVideoDataset': + def _make_encoded_video_dataset(self, data) -> "LabeledVideoDataset": raise NotImplementedError("Subclass must implement _make_encoded_video_dataset()") class VideoClassificationPathsDataSource(BaseVideoClassification, PathsDataSource): - def __init__( self, - clip_sampler: 'ClipSampler', + clip_sampler: "ClipSampler", video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, decode_audio: bool = True, decoder: str = "pyav", @@ -140,8 +139,8 @@ def __init__( extensions=("mp4", "avi"), ) - def _make_encoded_video_dataset(self, data) -> 'EncodedVideoDataset': - ds: EncodedVideoDataset = labeled_encoded_video_dataset( + def _make_encoded_video_dataset(self, data) -> "LabeledVideoDataset": + ds: LabeledVideoDataset = labeled_video_dataset( pathlib.Path(data), self.clip_sampler, video_sampler=self.video_sampler, @@ -155,10 +154,9 @@ class VideoClassificationFiftyOneDataSource( BaseVideoClassification, FiftyOneDataSource, ): - def __init__( self, - clip_sampler: 'ClipSampler', + clip_sampler: "ClipSampler", video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, decode_audio: bool = True, decoder: str = "pyav", @@ -179,7 +177,7 @@ def __init__( def label_cls(self): return fol.Classification - def _make_encoded_video_dataset(self, data: SampleCollection) -> 'EncodedVideoDataset': + def _make_encoded_video_dataset(self, data: SampleCollection) -> "LabeledVideoDataset": classes = self._get_classes(data) label_to_class_mapping = dict(enumerate(classes)) class_to_label_mapping = {c: lab for lab, c in label_to_class_mapping.items()} @@ -189,7 +187,7 @@ def _make_encoded_video_dataset(self, data: SampleCollection) -> 'EncodedVideoDa targets = [class_to_label_mapping[lab] for lab in labels] labeled_video_paths = LabeledVideoPaths(list(zip(filepaths, targets))) - ds: EncodedVideoDataset = EncodedVideoDataset( + ds: LabeledVideoDataset = LabeledVideoDataset( labeled_video_paths, self.clip_sampler, video_sampler=self.video_sampler, @@ -200,14 +198,13 @@ def _make_encoded_video_dataset(self, data: SampleCollection) -> 'EncodedVideoDa class VideoClassificationPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, val_transform: Optional[Dict[str, Callable]] = None, test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, - clip_sampler: Union[str, 'ClipSampler'] = "random", + clip_sampler: Union[str, "ClipSampler"] = "random", clip_duration: float = 2, clip_sampler_kwargs: Dict[str, Any] = None, video_sampler: Type[Sampler] = torch.utils.data.RandomSampler, @@ -266,7 +263,7 @@ def __init__( decode_audio=decode_audio, decoder=decoder, **data_source_kwargs, - ) + ), }, default_data_source=DefaultDataSources.FILES, ) @@ -283,7 +280,7 @@ def get_state_dict(self) -> Dict[str, Any]: } @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool) -> 'VideoClassificationPreprocess': + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool) -> "VideoClassificationPreprocess": return cls(**state_dict) def default_transforms(self) -> Dict[str, Callable]: @@ -298,22 +295,26 @@ def default_transforms(self) -> Dict[str, Callable]: ] return { - "post_tensor_transform": Compose([ - ApplyTransformToKey( - key="video", - transform=Compose([UniformTemporalSubsample(8)] + post_tensor_transform), - ), - ]), - "per_batch_transform_on_device": Compose([ - ApplyTransformToKey( - key="video", - transform=K.VideoSequential( - K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), - data_format="BCTHW", - same_on_frame=False - ) - ), - ]), + "post_tensor_transform": Compose( + [ + ApplyTransformToKey( + key="video", + transform=Compose([UniformTemporalSubsample(8)] + post_tensor_transform), + ), + ] + ), + "per_batch_transform_on_device": Compose( + [ + ApplyTransformToKey( + key="video", + transform=K.VideoSequential( + K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), + data_format="BCTHW", + same_on_frame=False, + ), + ), + ] + ), } diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py index 8e05069a2b..9345b7b19b 100644 --- a/flash/video/classification/model.py +++ b/flash/video/classification/model.py @@ -31,20 +31,21 @@ from flash.core.data.process import Serializer from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE +from flash.core.utilities.providers import _PYTORCHVIDEO _VIDEO_CLASSIFIER_BACKBONES = FlashRegistry("backbones") if _PYTORCHVIDEO_AVAILABLE: from pytorchvideo.models import hub + for fn_name in dir(hub): if "__" not in fn_name: fn = getattr(hub, fn_name) if isinstance(fn, FunctionType): - _VIDEO_CLASSIFIER_BACKBONES(fn=fn) + _VIDEO_CLASSIFIER_BACKBONES(fn=fn, providers=_PYTORCHVIDEO) class VideoClassifierFinetuning(BaseFinetuning): - def __init__(self, num_layers: int = 5, train_bn: bool = True, unfreeze_epoch: int = 1): super().__init__() self.num_layers = num_layers @@ -52,7 +53,7 @@ def __init__(self, num_layers: int = 5, train_bn: bool = True, unfreeze_epoch: i self.unfreeze_epoch = unfreeze_epoch def freeze_before_training(self, pl_module: LightningModule) -> None: - self.freeze(modules=list(pl_module.backbone.children())[:-self.num_layers], train_bn=self.train_bn) + self.freeze(modules=list(pl_module.backbone.children())[: -self.num_layers], train_bn=self.train_bn) def finetune_function( self, @@ -64,7 +65,7 @@ def finetune_function( if epoch != self.unfreeze_epoch: return self.unfreeze_and_add_param_group( - modules=list(pl_module.backbone.children())[-self.num_layers:], + modules=list(pl_module.backbone.children())[-self.num_layers :], optimizer=optimizer, train_bn=self.train_bn, ) @@ -94,7 +95,7 @@ class VideoClassifier(ClassificationTask): def __init__( self, num_classes: int, - backbone: Union[str, nn.Module] = "slow_r50", + backbone: Union[str, nn.Module] = "x3d_xs", backbone_kwargs: Optional[Dict] = None, pretrained: bool = True, loss_fn: Callable = F.cross_entropy, @@ -110,7 +111,7 @@ def __init__( optimizer=optimizer, metrics=metrics, learning_rate=learning_rate, - serializer=serializer or Labels() + serializer=serializer or Labels(), ) self.save_hyperparameters() @@ -146,8 +147,8 @@ def on_train_epoch_start(self) -> None: encoded_dataset._video_sampler.set_epoch(self.trainer.current_epoch) super().on_train_epoch_start() - def step(self, batch: Any, batch_idx: int) -> Any: - return super().step((batch["video"], batch["label"]), batch_idx) + def step(self, batch: Any, batch_idx: int, metrics) -> Any: + return super().step((batch["video"], batch["label"]), batch_idx, metrics) def forward(self, x: Any) -> Any: x = self.backbone(x) @@ -165,7 +166,5 @@ def configure_finetune_callback(self) -> List[Callback]: @staticmethod def _ci_benchmark_fn(history: List[Dict[str, Any]]): - """ - This function is used only for debugging usage with CI - """ - assert history[-1]["val_accuracy"] > 0.80 + """This function is used only for debugging usage with CI.""" + assert history[-1]["val_accuracy"] > 0.70 diff --git a/flash_examples/audio_classification.py b/flash_examples/audio_classification.py new file mode 100644 index 0000000000..6dea056c18 --- /dev/null +++ b/flash_examples/audio_classification.py @@ -0,0 +1,49 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import flash +from flash.audio import AudioClassificationData +from flash.core.data.utils import download_data +from flash.core.finetuning import FreezeUnfreeze +from flash.image import ImageClassifier + +# 1. Create the DataModule +download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data") + +datamodule = AudioClassificationData.from_folders( + train_folder="data/urban8k_images/train", + val_folder="data/urban8k_images/val", + spectrogram_size=(64, 64), +) + +# 2. Build the model. +model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes) + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) +trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1)) + +# 4. Predict what's on few images! air_conditioner, children_playing, siren e.t.c +predictions = model.predict( + [ + "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg", + "data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg", + "data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg", + ] +) +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("audio_classification_model.pt") diff --git a/flash_examples/custom_task.py b/flash_examples/custom_task.py index 2ab29f6526..15cc3b9fc7 100644 --- a/flash_examples/custom_task.py +++ b/flash_examples/custom_task.py @@ -35,7 +35,6 @@ class RegressionTask(flash.Task): - def __init__(self, num_inputs, learning_rate=0.2, metrics=None): # what kind of model do we want? model = nn.Linear(num_inputs, 1) @@ -85,7 +84,6 @@ def forward(self, x): class NumpyDataSource(DataSource[Tuple[ND, ND]]): - def load_data(self, data: Tuple[ND, ND], dataset: Optional[Any] = None) -> List[Dict[str, Any]]: if self.training: dataset.num_inputs = data[0].shape[1] @@ -97,7 +95,6 @@ def predict_load_data(data: ND) -> List[Dict[str, Any]]: class NumpyPreprocess(Preprocess): - def __init__( self, train_transform: Optional[Dict[str, Callable]] = None, @@ -160,16 +157,20 @@ class NumpyDataModule(flash.DataModule): datamodule = NumpyDataModule.from_numpy(x, y) model = RegressionTask(num_inputs=datamodule.train_dataset.num_inputs) -trainer = flash.Trainer(max_epochs=20, progress_bar_refresh_rate=20, checkpoint_callback=False) +trainer = flash.Trainer( + max_epochs=20, progress_bar_refresh_rate=20, checkpoint_callback=False, gpus=torch.cuda.device_count() +) trainer.fit(model, datamodule=datamodule) -predict_data = np.array([ - [0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403], - [-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072], - [0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072], - [-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384], - [-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094], -]) +predict_data = np.array( + [ + [0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403], + [-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072], + [0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072], + [-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384], + [-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094], + ] +) predictions = model.predict(predict_data) print(predictions) diff --git a/flash_examples/graph_classification.py b/flash_examples/graph_classification.py new file mode 100644 index 0000000000..4519f70c33 --- /dev/null +++ b/flash_examples/graph_classification.py @@ -0,0 +1,44 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import flash +from flash.core.utilities.imports import example_requires +from flash.graph import GraphClassificationData, GraphClassifier + +example_requires("graph") + +from torch_geometric.datasets import TUDataset # noqa: E402 + +# 1. Create the DataModule +dataset = TUDataset(root="data", name="KKI") + +datamodule = GraphClassificationData.from_datasets( + train_dataset=dataset, + val_split=0.1, +) + +# 2. Build the task +model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes) + +# 3. Create the trainer and fit the model +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) +trainer.fit(model, datamodule=datamodule) + +# 4. Classify some graphs! +predictions = model.predict(dataset[:3]) +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("graph_classification.pt") diff --git a/flash_examples/image_classification.py b/flash_examples/image_classification.py index a675938c57..3b9413a629 100644 --- a/flash_examples/image_classification.py +++ b/flash_examples/image_classification.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch + import flash from flash.core.data.utils import download_data from flash.image import ImageClassificationData, ImageClassifier @@ -27,15 +29,17 @@ model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes) # 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=3) +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Predict what's on a few images! ants or bees? -predictions = model.predict([ - "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", - "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", - "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg", -]) +predictions = model.predict( + [ + "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", + "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", + "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/image_classification_multi_label.py b/flash_examples/image_classification_multi_label.py index 00e86d7f0b..947446a9c0 100644 --- a/flash_examples/image_classification_multi_label.py +++ b/flash_examples/image_classification_multi_label.py @@ -11,10 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os.path as osp -from typing import List, Tuple - -import pandas as pd +import torch import flash from flash.core.data.utils import download_data @@ -24,37 +21,31 @@ # Data set from the paper “Movie Genre Classification based on Poster Images with Deep Neural Networks”. # More info here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/ download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip") -genres = ["Action", "Romance", "Crime", "Thriller", "Adventure"] - - -def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], List[List[int]]]: - metadata = pd.read_csv(osp.join(root, data, "metadata.csv")) - return ([osp.join(root, data, row['Id'] + ".jpg") for _, row in metadata.iterrows()], - [[int(row[genre]) for genre in genres] for _, row in metadata.iterrows()]) - -train_files, train_targets = load_data('train') -datamodule = ImageClassificationData.from_files( - train_files=train_files, - train_targets=train_targets, - val_split=0.1, +datamodule = ImageClassificationData.from_csv( + "Id", + ["Action", "Romance", "Crime", "Thriller", "Adventure"], + train_file="data/movie_posters/train/metadata.csv", + val_file="data/movie_posters/val/metadata.csv", image_size=(128, 128), ) # 2. Build the task -model = ImageClassifier(backbone="resnet18", num_classes=len(genres), multi_label=True) +model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, multi_label=True) # 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=3) +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Predict the genre of a few movies! -predictions = model.predict([ - "data/movie_posters/predict/tt0085318.jpg", - "data/movie_posters/predict/tt0089461.jpg", - "data/movie_posters/predict/tt0097179.jpg", -]) +predictions = model.predict( + [ + "data/movie_posters/predict/tt0085318.jpg", + "data/movie_posters/predict/tt0089461.jpg", + "data/movie_posters/predict/tt0097179.jpg", + ] +) print(predictions) -# 7. Save the model! +# 5. Save the model! trainer.save_checkpoint("image_classification_multi_label_model.pt") diff --git a/flash_examples/image_embedder.py b/flash_examples/image_embedder.py index cd786472c3..5a4de94fcf 100644 --- a/flash_examples/image_embedder.py +++ b/flash_examples/image_embedder.py @@ -22,3 +22,4 @@ # 3. Generate an embedding from an image path. embeddings = embedder.predict(["data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg"]) +print(embeddings) diff --git a/flash_examples/instance_segmentation.py b/flash_examples/instance_segmentation.py new file mode 100644 index 0000000000..3fdc4e8a4b --- /dev/null +++ b/flash_examples/instance_segmentation.py @@ -0,0 +1,55 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import flash +from flash.core.utilities.imports import example_requires +from flash.image import InstanceSegmentation, InstanceSegmentationData + +example_requires("image") + +import icedata # noqa: E402 + +# 1. Create the DataModule +data_dir = icedata.pets.load_data() + +datamodule = InstanceSegmentationData.from_folders( + train_folder=data_dir, + val_split=0.1, + parser=partial(icedata.pets.parser, mask=True), +) + +# 2. Build the task +model = InstanceSegmentation( + head="mask_rcnn", + backbone="resnet18_fpn", + num_classes=datamodule.num_classes, +) + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer(max_epochs=1) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") + +# 4. Detect objects in a few images! +predictions = model.predict( + [ + str(data_dir / "images/yorkshire_terrier_9.jpg"), + str(data_dir / "images/english_cocker_spaniel_1.jpg"), + str(data_dir / "images/scottish_terrier_1.jpg"), + ] +) +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("instance_segmentation_model.pt") diff --git a/flash_examples/integrations/fiftyone/image_classification.py b/flash_examples/integrations/fiftyone/image_classification.py index ebf40df56c..b1f5fb56cf 100644 --- a/flash_examples/integrations/fiftyone/image_classification.py +++ b/flash_examples/integrations/fiftyone/image_classification.py @@ -13,6 +13,8 @@ # limitations under the License. from itertools import chain +import torch + import flash from flash.core.classification import FiftyOneLabels, Labels from flash.core.data.utils import download_data @@ -39,6 +41,7 @@ ) trainer = flash.Trainer( max_epochs=1, + gpus=torch.cuda.device_count(), limit_train_batches=1, limit_val_batches=1, ) diff --git a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py index 5ec81bdf6f..9ef31609d5 100644 --- a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py +++ b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py @@ -14,6 +14,7 @@ from itertools import chain import fiftyone as fo +import torch import flash from flash.core.classification import FiftyOneLabels, Labels @@ -53,6 +54,7 @@ ) trainer = flash.Trainer( max_epochs=1, + gpus=torch.cuda.device_count(), limit_train_batches=1, limit_val_batches=1, ) diff --git a/flash_examples/integrations/fiftyone/image_embedding.py b/flash_examples/integrations/fiftyone/image_embedding.py index b9d1651ceb..019bd9cffe 100644 --- a/flash_examples/integrations/fiftyone/image_embedding.py +++ b/flash_examples/integrations/fiftyone/image_embedding.py @@ -28,7 +28,7 @@ ) # 3 Load model -embedder = ImageEmbedder(backbone="resnet101", embedding_dim=128) +embedder = ImageEmbedder(backbone="resnet101") # 4 Generate embeddings filepaths = dataset.values("filepath") diff --git a/flash_examples/integrations/labelstudio/image_classification.py b/flash_examples/integrations/labelstudio/image_classification.py index 1ef31c6b2e..12d9df7952 100644 --- a/flash_examples/integrations/labelstudio/image_classification.py +++ b/flash_examples/integrations/labelstudio/image_classification.py @@ -9,8 +9,8 @@ # 1. Load export data datamodule = ImageClassificationData.from_labelstudio( - export_json='data/project.json', - data_folder='data/upload/', + export_json="data/project.json", + data_folder="data/upload/", val_split=0.8, ) diff --git a/flash_examples/integrations/labelstudio/text_classification.py b/flash_examples/integrations/labelstudio/text_classification.py index 4d4f260991..88a315d535 100644 --- a/flash_examples/integrations/labelstudio/text_classification.py +++ b/flash_examples/integrations/labelstudio/text_classification.py @@ -8,7 +8,7 @@ backbone = "prajjwal1/bert-medium" datamodule = TextClassificationData.from_labelstudio( - export_json='data/project.json', + export_json="data/project.json", val_split=0.8, backbone=backbone, ) diff --git a/flash_examples/integrations/labelstudio/video_classification.py b/flash_examples/integrations/labelstudio/video_classification.py index 26315fe4a9..af4c206590 100644 --- a/flash_examples/integrations/labelstudio/video_classification.py +++ b/flash_examples/integrations/labelstudio/video_classification.py @@ -9,8 +9,8 @@ # 1. Load export data datamodule = VideoClassificationData.from_labelstudio( - export_json='data/project.json', - data_folder='data/upload/', + export_json="data/project.json", + data_folder="data/upload/", val_split=0.8, clip_sampler="uniform", clip_duration=1, diff --git a/flash_examples/keypoint_detection.py b/flash_examples/keypoint_detection.py new file mode 100644 index 0000000000..b1fa29cc02 --- /dev/null +++ b/flash_examples/keypoint_detection.py @@ -0,0 +1,54 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import flash +from flash.core.utilities.imports import example_requires +from flash.image import KeypointDetectionData, KeypointDetector + +example_requires("image") + +import icedata # noqa: E402 + +# 1. Create the DataModule +data_dir = icedata.biwi.load_data() + +datamodule = KeypointDetectionData.from_folders( + train_folder=data_dir, + val_split=0.1, + parser=icedata.biwi.parser, +) + +# 2. Build the task +model = KeypointDetector( + head="keypoint_rcnn", + backbone="resnet18_fpn", + num_keypoints=1, + num_classes=datamodule.num_classes, +) + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer(max_epochs=1) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") + +# 4. Detect objects in a few images! +predictions = model.predict( + [ + str(data_dir / "biwi_sample/images/0.jpg"), + str(data_dir / "biwi_sample/images/1.jpg"), + str(data_dir / "biwi_sample/images/10.jpg"), + ] +) +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("keypoint_detection_model.pt") diff --git a/flash_examples/object_detection.py b/flash_examples/object_detection.py index 4f488e1e11..1a5dddbce9 100644 --- a/flash_examples/object_detection.py +++ b/flash_examples/object_detection.py @@ -17,27 +17,30 @@ # 1. Create the DataModule # Dataset Credit: https://www.kaggle.com/ultralytics/coco128 -download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "finetuning/data/") +download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/") datamodule = ObjectDetectionData.from_coco( train_folder="data/coco128/images/train2017/", - train_ann_file="finetuning/data/coco128/annotations/instances_train2017.json", + train_ann_file="data/coco128/annotations/instances_train2017.json", val_split=0.1, + image_size=128, ) # 2. Build the task -model = ObjectDetector(model="retinanet", num_classes=datamodule.num_classes) +model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=128) # 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=3) -trainer.finetune(model, datamodule=datamodule) +trainer = flash.Trainer(max_epochs=1) +trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Detect objects in a few images! -predictions = model.predict([ - "data/coco128/images/train2017/000000000625.jpg", - "data/coco128/images/train2017/000000000626.jpg", - "data/coco128/images/train2017/000000000629.jpg", -]) +predictions = model.predict( + [ + "data/coco128/images/train2017/000000000625.jpg", + "data/coco128/images/train2017/000000000626.jpg", + "data/coco128/images/train2017/000000000629.jpg", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/pointcloud_detection.py b/flash_examples/pointcloud_detection.py new file mode 100644 index 0000000000..ff29265355 --- /dev/null +++ b/flash_examples/pointcloud_detection.py @@ -0,0 +1,47 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import flash +from flash.core.data.utils import download_data +from flash.pointcloud import PointCloudObjectDetector, PointCloudObjectDetectorData + +# 1. Create the DataModule +# Dataset Credit: http://www.semantic-kitti.org/ +download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/") + +datamodule = PointCloudObjectDetectorData.from_folders( + train_folder="data/KITTI_Tiny/Kitti/train", + val_folder="data/KITTI_Tiny/Kitti/val", +) + +# 2. Build the task +model = PointCloudObjectDetector(backbone="pointpillars_kitti", num_classes=datamodule.num_classes) + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer( + max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count() +) +trainer.fit(model, datamodule) + +# 4. Predict what's within a few PointClouds? +predictions = model.predict( + [ + "data/KITTI_Tiny/Kitti/predict/scans/000000.bin", + "data/KITTI_Tiny/Kitti/predict/scans/000001.bin", + ] +) + +# 5. Save the model! +trainer.save_checkpoint("pointcloud_detection_model.pt") diff --git a/flash_examples/pointcloud_segmentation.py b/flash_examples/pointcloud_segmentation.py new file mode 100644 index 0000000000..7d1a0eb538 --- /dev/null +++ b/flash_examples/pointcloud_segmentation.py @@ -0,0 +1,47 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import flash +from flash.core.data.utils import download_data +from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData + +# 1. Create the DataModule +# Dataset Credit: http://www.semantic-kitti.org/ +download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/") + +datamodule = PointCloudSegmentationData.from_folders( + train_folder="data/SemanticKittiTiny/train", + val_folder="data/SemanticKittiTiny/val", +) + +# 2. Build the task +model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=datamodule.num_classes) + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer( + max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count() +) +trainer.fit(model, datamodule) + +# 4. Predict what's within a few PointClouds? +predictions = model.predict( + [ + "data/SemanticKittiTiny/predict/000000.bin", + "data/SemanticKittiTiny/predict/000001.bin", + ] +) + +# 5. Save the model! +trainer.save_checkpoint("pointcloud_segmentation_model.pt") diff --git a/flash_examples/semantic_segmentation.py b/flash_examples/semantic_segmentation.py index 83aa617c62..a3800f2508 100644 --- a/flash_examples/semantic_segmentation.py +++ b/flash_examples/semantic_segmentation.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch + import flash from flash.core.data.utils import download_data from flash.image import SemanticSegmentation, SemanticSegmentationData @@ -20,34 +22,36 @@ # More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge download_data( "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", - "./data" + "./data", ) datamodule = SemanticSegmentationData.from_folders( train_folder="data/CameraRGB", train_target_folder="data/CameraSeg", val_split=0.1, - image_size=(200, 200), + image_size=(256, 256), num_classes=21, ) # 2. Build the task model = SemanticSegmentation( - backbone="mobilenet_v3_large", - head="fcn", + backbone="mobilenetv3_large_100", + head="fpn", num_classes=datamodule.num_classes, ) # 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=3) +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Segment a few images! -predictions = model.predict([ - "data/CameraRGB/F61-1.png", - "data/CameraRGB/F62-1.png", - "data/CameraRGB/F63-1.png", -]) +predictions = model.predict( + [ + "data/CameraRGB/F61-1.png", + "data/CameraRGB/F62-1.png", + "data/CameraRGB/F63-1.png", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/serve/generic/boston_prediction/inference_server.py b/flash_examples/serve/generic/boston_prediction/inference_server.py index 1e1d958e9f..acd1735ae9 100644 --- a/flash_examples/serve/generic/boston_prediction/inference_server.py +++ b/flash_examples/serve/generic/boston_prediction/inference_server.py @@ -35,7 +35,6 @@ class PricePrediction(ModelComponent): - def __init__(self, model): # skipcq: PYL-W0621 self.model = model diff --git a/flash_examples/serve/generic/detection/inference.py b/flash_examples/serve/generic/detection/inference.py index 0971fb380c..813359a6dc 100644 --- a/flash_examples/serve/generic/detection/inference.py +++ b/flash_examples/serve/generic/detection/inference.py @@ -18,16 +18,12 @@ class ObjectDetection(ModelComponent): - def __init__(self, model): self.model = model @expose( inputs={"img": Image()}, - outputs={ - "boxes": Repeated(BBox()), - "labels": Repeated(Label("classes.txt")) - }, + outputs={"boxes": Repeated(BBox()), "labels": Repeated(Label("classes.txt"))}, ) def detect(self, img): img = img.permute(0, 3, 2, 1).float() / 255 diff --git a/flash/image/detection/finetuning.py b/flash_examples/serve/speech_recognition/client.py similarity index 55% rename from flash/image/detection/finetuning.py rename to flash_examples/serve/speech_recognition/client.py index c1ca20072d..c855a37204 100644 --- a/flash/image/detection/finetuning.py +++ b/flash_examples/serve/speech_recognition/client.py @@ -11,19 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import base64 +from pathlib import Path -from flash.core.finetuning import FlashBaseFinetuning +import requests +import flash -class ObjectDetectionFineTuning(FlashBaseFinetuning): - """ - Freezes the backbone during Detector training. - """ +with (Path(flash.ASSETS_ROOT) / "example.wav").open("rb") as f: + audio_str = base64.b64encode(f.read()).decode("UTF-8") - def __init__(self, train_bn: bool = True) -> None: - super().__init__(train_bn=train_bn) +body = {"session": "UUID", "payload": {"inputs": {"data": audio_str}}} +resp = requests.post("http://127.0.0.1:8000/predict", json=body) - def freeze_before_training(self, pl_module: pl.LightningModule) -> None: - model = pl_module.model - self.freeze(modules=model.backbone, train_bn=self.train_bn) +print(resp.json()) diff --git a/flash_examples/serve/speech_recognition/inference_server.py b/flash_examples/serve/speech_recognition/inference_server.py new file mode 100644 index 0000000000..bbc4479624 --- /dev/null +++ b/flash_examples/serve/speech_recognition/inference_server.py @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.audio import SpeechRecognition + +model = SpeechRecognition.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/speech_recognition_model.pt") +model.serve() diff --git a/flash_examples/serve/tabular_classification/inference_server.py b/flash_examples/serve/tabular_classification/inference_server.py index f6aac866e2..4b58b8f691 100644 --- a/flash_examples/serve/tabular_classification/inference_server.py +++ b/flash_examples/serve/tabular_classification/inference_server.py @@ -15,5 +15,5 @@ from flash.tabular import TabularClassifier model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt") -model.serializer = Labels(['Did not survive', 'Survived']) +model.serializer = Labels(["Did not survive", "Survived"]) model.serve() diff --git a/flash_examples/speech_recognition.py b/flash_examples/speech_recognition.py new file mode 100644 index 0000000000..1672dbe1fe --- /dev/null +++ b/flash_examples/speech_recognition.py @@ -0,0 +1,42 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import flash +from flash.audio import SpeechRecognition, SpeechRecognitionData +from flash.core.data.utils import download_data + +# # 1. Create the DataModule +download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data") + +datamodule = SpeechRecognitionData.from_json( + input_fields="file", + target_fields="text", + train_file="data/timit/train.json", + test_file="data/timit/test.json", +) + +# 2. Build the task +model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h") + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count()) +trainer.finetune(model, datamodule=datamodule, strategy="no_freeze") + +# 4. Predict on audio files! +predictions = model.predict(["data/timit/example.wav"]) +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("speech_recognition_model.pt") diff --git a/flash_examples/style_transfer.py b/flash_examples/style_transfer.py index 37500e9358..607f5ad0f6 100644 --- a/flash_examples/style_transfer.py +++ b/flash_examples/style_transfer.py @@ -13,6 +13,8 @@ # limitations under the License. import os +import torch + import flash from flash.core.data.utils import download_data from flash.image.style_transfer import StyleTransfer, StyleTransferData @@ -26,15 +28,17 @@ model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg")) # 3. Create the trainer and train the model -trainer = flash.Trainer(max_epochs=3) +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) trainer.fit(model, datamodule=datamodule) # 4. Apply style transfer to a few images! -predictions = model.predict([ - "data/coco128/images/train2017/000000000625.jpg", - "data/coco128/images/train2017/000000000626.jpg", - "data/coco128/images/train2017/000000000629.jpg", -]) +predictions = model.predict( + [ + "data/coco128/images/train2017/000000000625.jpg", + "data/coco128/images/train2017/000000000626.jpg", + "data/coco128/images/train2017/000000000629.jpg", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/tabular_classification.py b/flash_examples/tabular_classification.py index fa3a2cc23e..ef80723afa 100644 --- a/flash_examples/tabular_classification.py +++ b/flash_examples/tabular_classification.py @@ -11,14 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch + import flash from flash.core.data.utils import download_data -from flash.tabular import TabularClassifier, TabularData +from flash.tabular import TabularClassificationData, TabularClassifier # 1. Create the DataModule download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data") -datamodule = TabularData.from_csv( +datamodule = TabularClassificationData.from_csv( ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], "Fare", target_fields="Survived", @@ -30,7 +32,7 @@ model = TabularClassifier.from_data(datamodule) # 3. Create the trainer and train the model -trainer = flash.Trainer(max_epochs=3) +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) trainer.fit(model, datamodule=datamodule) # 4. Generate predictions from a CSV diff --git a/flash_examples/template.py b/flash_examples/template.py index 66ce579a83..0d8c7016ed 100644 --- a/flash_examples/template.py +++ b/flash_examples/template.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import numpy as np +import torch from sklearn import datasets import flash @@ -27,15 +28,17 @@ model = TemplateSKLearnClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes) # 3. Create the trainer and train the model -trainer = flash.Trainer(max_epochs=3) +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) trainer.fit(model, datamodule=datamodule) # 4. Classify a few examples -predictions = model.predict([ - np.array([4.9, 3.0, 1.4, 0.2]), - np.array([6.9, 3.2, 5.7, 2.3]), - np.array([7.2, 3.0, 5.8, 1.6]), -]) +predictions = model.predict( + [ + np.array([4.9, 3.0, 1.4, 0.2]), + np.array([6.9, 3.2, 5.7, 2.3]), + np.array([7.2, 3.0, 5.8, 1.6]), + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/text_classification.py b/flash_examples/text_classification.py index 1924d408de..3d62dbb0dc 100644 --- a/flash_examples/text_classification.py +++ b/flash_examples/text_classification.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch + import flash from flash.core.data.utils import download_data from flash.text import TextClassificationData, TextClassifier @@ -30,15 +32,17 @@ model = TextClassifier(backbone="prajjwal1/bert-medium", num_classes=datamodule.num_classes) # 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=3) +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Classify a few sentences! How was the movie? -predictions = model.predict([ - "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", - "The worst movie in the history of cinema.", - "I come from Bulgaria where it 's almost impossible to have a tornado.", -]) +predictions = model.predict( + [ + "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.", + "The worst movie in the history of cinema.", + "I come from Bulgaria where it 's almost impossible to have a tornado.", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/text_classification_multi_label.py b/flash_examples/text_classification_multi_label.py index 57222bf560..72f87b7c81 100644 --- a/flash_examples/text_classification_multi_label.py +++ b/flash_examples/text_classification_multi_label.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch + import flash from flash.core.data.utils import download_data from flash.text import TextClassificationData, TextClassifier @@ -36,15 +38,17 @@ ) # 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=3) +trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Generate predictions for a few comments! -predictions = model.predict([ - "No, he is an arrogant, self serving, immature idiot. Get it right.", - "U SUCK HANNAH MONTANA", - "Would you care to vote? Thx.", -]) +predictions = model.predict( + [ + "No, he is an arrogant, self serving, immature idiot. Get it right.", + "U SUCK HANNAH MONTANA", + "Would you care to vote? Thx.", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/translation.py b/flash_examples/translation.py index 2a0d7889f2..fc82bb767a 100644 --- a/flash_examples/translation.py +++ b/flash_examples/translation.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch + import flash from flash.core.data.utils import download_data from flash.text import TranslationData, TranslationTask @@ -30,15 +32,17 @@ model = TranslationTask(backbone="Helsinki-NLP/opus-mt-en-ro") # 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=3) +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule) # 4. Translate something! -predictions = model.predict([ - "BBC News went to meet one of the project's first graduates.", - "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", - "Of course, it's still early in the election cycle.", -]) +predictions = model.predict( + [ + "BBC News went to meet one of the project's first graduates.", + "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", + "Of course, it's still early in the election cycle.", + ] +) print(predictions) # 5. Save the model! diff --git a/flash_examples/video_classification.py b/flash_examples/video_classification.py index 1ecfd25959..99c7422dcd 100644 --- a/flash_examples/video_classification.py +++ b/flash_examples/video_classification.py @@ -13,6 +13,8 @@ # limitations under the License. import os +import torch + import flash from flash.core.data.utils import download_data from flash.video import VideoClassificationData, VideoClassifier @@ -33,7 +35,7 @@ model = VideoClassifier(backbone="x3d_xs", num_classes=datamodule.num_classes, pretrained=False) # 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=3) +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Make a prediction diff --git a/flash_examples/visualizations/pointcloud_detection.py b/flash_examples/visualizations/pointcloud_detection.py new file mode 100644 index 0000000000..899e30a3aa --- /dev/null +++ b/flash_examples/visualizations/pointcloud_detection.py @@ -0,0 +1,47 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import flash +from flash.core.data.utils import download_data +from flash.pointcloud.detection import launch_app, PointCloudObjectDetector, PointCloudObjectDetectorData + +# 1. Create the DataModule +# Dataset Credit: http://www.semantic-kitti.org/ +download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/") + +datamodule = PointCloudObjectDetectorData.from_folders( + train_folder="data/KITTI_Tiny/Kitti/train", + val_folder="data/KITTI_Tiny/Kitti/val", +) + +# 2. Build the task +model = PointCloudObjectDetector(backbone="pointpillars_kitti", num_classes=datamodule.num_classes) + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer( + max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count() +) +trainer.fit(model, datamodule) + +# 4. Predict what's within a few PointClouds? +predictions = model.predict(["data/KITTI_Tiny/Kitti/predict/scans/000000.bin"]) + +# 5. Save the model! +trainer.save_checkpoint("pointcloud_segmentation_model.pt") + +# 6. Optional Visualize +app = launch_app(datamodule) +# app.show_train_dataset() +app.show_predictions(predictions) diff --git a/flash_examples/visualizations/pointcloud_segmentation.py b/flash_examples/visualizations/pointcloud_segmentation.py new file mode 100644 index 0000000000..c50ea7b958 --- /dev/null +++ b/flash_examples/visualizations/pointcloud_segmentation.py @@ -0,0 +1,52 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import flash +from flash.core.data.utils import download_data +from flash.pointcloud.segmentation import launch_app, PointCloudSegmentation, PointCloudSegmentationData + +# 1. Create the DataModule +# Dataset Credit: http://www.semantic-kitti.org/ +download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/") + +datamodule = PointCloudSegmentationData.from_folders( + train_folder="data/SemanticKittiTiny/train", + val_folder="data/SemanticKittiTiny/val", +) + +# 2. Build the task +model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=datamodule.num_classes) + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer( + max_epochs=1, limit_train_batches=0, limit_val_batches=0, num_sanity_val_steps=0, gpus=torch.cuda.device_count() +) +trainer.fit(model, datamodule) + +# 4. Predict what's within a few PointClouds? +predictions = model.predict( + [ + "data/SemanticKittiTiny/predict/000000.bin", + "data/SemanticKittiTiny/predict/000001.bin", + ] +) + +# 5. Save the model! +trainer.save_checkpoint("pointcloud_segmentation_model.pt") + +# 6. Optional Visualize +app = launch_app(datamodule) +# app.show_train_dataset() +app.show_predictions(predictions) diff --git a/pyproject.toml b/pyproject.toml index cbfacb0aeb..e18a6fbac5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,6 @@ [tool.autopep8] ignore = ["E731"] + + +[tool.black] +line-length = 120 diff --git a/requirements.txt b/requirements.txt index 01330917d4..e367ff1793 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,8 @@ -torch>=1.8 +packaging +torch torchmetrics -pytorch-lightning>=1.3.1 +pytorch-lightning>=1.4.0 pyDeprecate -PyYAML>=5.1 -numpy pandas<1.3.0 -packaging -tqdm +jsonargparse[signatures]>=3.17.0 +click>=7.1.2 diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt new file mode 100644 index 0000000000..4c198da250 --- /dev/null +++ b/requirements/datatype_audio.txt @@ -0,0 +1,4 @@ +torchaudio +soundfile>=0.10.2 +transformers>=4.5 +datasets>=1.8 diff --git a/requirements/datatype_graph.txt b/requirements/datatype_graph.txt new file mode 100644 index 0000000000..9109e2167f --- /dev/null +++ b/requirements/datatype_graph.txt @@ -0,0 +1,3 @@ +torch-scatter +torch-sparse +torch-geometric diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt index ab91d28d57..aa9fe14c15 100644 --- a/requirements/datatype_image.txt +++ b/requirements/datatype_image.txt @@ -3,7 +3,8 @@ timm>=0.4.5 lightning-bolts>=0.3.3 Pillow>=7.2 kornia>=0.5.1,<0.5.4 -matplotlib -pycocotools>=2.0.2 ; python_version >= "3.7" -fiftyone -pystiche>=0.7.2 +pystiche==1.* +segmentation-models-pytorch +icevision>=0.8 +icedata +effdet diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt new file mode 100644 index 0000000000..f61e3f9c25 --- /dev/null +++ b/requirements/datatype_image_extras.txt @@ -0,0 +1,2 @@ +matplotlib +fiftyone diff --git a/requirements/datatype_pointcloud.txt b/requirements/datatype_pointcloud.txt new file mode 100644 index 0000000000..cc6437f44c --- /dev/null +++ b/requirements/datatype_pointcloud.txt @@ -0,0 +1,4 @@ +open3d==0.13 +torch==1.7.1 +torchvision +tensorboard diff --git a/requirements/datatype_video.txt b/requirements/datatype_video.txt index 85bc82a5df..28279e2293 100644 --- a/requirements/datatype_video.txt +++ b/requirements/datatype_video.txt @@ -1,5 +1,4 @@ torchvision Pillow>=7.2 kornia>=0.5.1,<0.5.4 -pytorchvideo==0.1.0 -fiftyone +pytorchvideo==0.1.2 diff --git a/requirements/datatype_video_extras.txt b/requirements/datatype_video_extras.txt new file mode 100644 index 0000000000..00de5ca1d2 --- /dev/null +++ b/requirements/datatype_video_extras.txt @@ -0,0 +1 @@ +fiftyone diff --git a/requirements/docs.txt b/requirements/docs.txt index a126cd5db3..5a6057f8e8 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,4 +1,4 @@ -sphinx>=4.0 +sphinx>=4.0,<4.1 recommonmark # fails with badges m2r # fails with multi-line text nbsphinx>=0.8 diff --git a/requirements/test.txt b/requirements/test.txt index 6a4674f7d9..3fecfe24d9 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -11,7 +11,6 @@ twine==3.2 # formatting pre-commit isort -yapf #mypy scikit-learn pytest_mock diff --git a/setup.cfg b/setup.cfg index 73aff69cad..8ed86d15f0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -72,18 +72,6 @@ ignore = .circleci -[yapf] -based_on_style = pep8 -spaces_before_comment = 2 -split_before_logical_operator = true -COLUMN_LIMIT = 120 -COALESCE_BRACKETS = true -DEDENT_CLOSING_BRACKETS = true -ALLOW_SPLIT_BEFORE_DICT_VALUE = false -BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true -NO_SPACES_AROUND_SELECTED_BINARY_OPERATORS = false - - [mypy] # Typing tests is low priority, but enabling type checking on the # untyped test functions (using `--check-untyped-defs`) is still diff --git a/setup.py b/setup.py index d581ce9275..96fb1a6164 100644 --- a/setup.py +++ b/setup.py @@ -33,8 +33,8 @@ def _load_py_module(fname, pkg="flash"): return py -about = _load_py_module('__about__.py') -setup_tools = _load_py_module('setup_tools.py') +about = _load_py_module("__about__.py") +setup_tools = _load_py_module("setup_tools.py") long_description = setup_tools._load_readme_description( _PATH_ROOT, @@ -49,13 +49,19 @@ def _load_py_module(fname, pkg="flash"): "text": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_text.txt"), "tabular": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_tabular.txt"), "image": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_image.txt"), + "image_extras": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_image_extras.txt"), "video": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_video.txt"), + "pointcloud": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_pointcloud.txt"), + "video_extras": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_video_extras.txt"), "serve": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="serve.txt"), + "audio": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_audio.txt"), + "graph": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_graph.txt"), } -# remove possible duplicate. extras["vision"] = list(set(extras["image"] + extras["video"])) -extras["all"] = list(set(extras["vision"] + extras["tabular"] + extras["text"])) +extras["all"] = list( + set(extras["vision"] + extras["tabular"] + extras["text"]) +) # + extras["pointcloud"] dependencies conflicts extras["dev"] = list(set(extras["all"] + extras["test"] + extras["docs"])) # https://packaging.python.org/discussions/install-requires-vs-requirements / @@ -77,10 +83,13 @@ def _load_py_module(fname, pkg="flash"): long_description_content_type="text/markdown", include_package_data=True, extras_require=extras, + entry_points={ + "console_scripts": ["flash=flash.__main__:main"], + }, zip_safe=False, keywords=["deep learning", "pytorch", "AI"], python_requires=">=3.6", - install_requires=setup_tools._load_requirements(_PATH_ROOT, file_name='requirements.txt'), + install_requires=setup_tools._load_requirements(_PATH_ROOT, file_name="requirements.txt"), project_urls={ "Bug Tracker": "https://github.com/PyTorchLightning/lightning-flash/issues", "Documentation": "https://lightning-flash.rtfd.io/en/latest/", diff --git a/tests/__init__.py b/tests/__init__.py index c64310c910..2be74bcdc7 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,5 +2,5 @@ # TorchVision hotfix https://github.com/pytorch/vision/issues/1938 opener = urllib.request.build_opener() -opener.addheaders = [('User-agent', 'Mozilla/5.0')] +opener.addheaders = [("User-agent", "Mozilla/5.0")] urllib.request.install_opener(opener) diff --git a/tests/audio/__init__.py b/tests/audio/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/audio/classification/__init__.py b/tests/audio/classification/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py new file mode 100644 index 0000000000..626ca12b93 --- /dev/null +++ b/tests/audio/classification/test_data.py @@ -0,0 +1,340 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path +from typing import Any, List, Tuple + +import numpy as np +import pytest +import torch +import torch.nn as nn + +from flash.audio import AudioClassificationData +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.transforms import ApplyToKeys +from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from tests.helpers.utils import _AUDIO_TESTING + +if _TORCHVISION_AVAILABLE: + import torchvision + +if _PIL_AVAILABLE: + from PIL import Image + + +def _rand_image(size: Tuple[int, int] = None): + if size is None: + _size = np.random.choice([196, 244]) + size = (_size, _size) + return Image.fromarray(np.random.randint(0, 255, (*size, 3), dtype="uint8")) + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_from_filepaths_smoke(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + (tmpdir / "b").mkdir() + _rand_image().save(tmpdir / "a_1.png") + _rand_image().save(tmpdir / "b_1.png") + + train_images = [ + str(tmpdir / "a_1.png"), + str(tmpdir / "b_1.png"), + ] + + spectrograms_data = AudioClassificationData.from_files( + train_files=train_images, + train_targets=[1, 2], + batch_size=2, + num_workers=0, + ) + assert spectrograms_data.train_dataloader() is not None + assert spectrograms_data.val_dataloader() is None + assert spectrograms_data.test_dataloader() is None + + data = next(iter(spectrograms_data.train_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2,) + assert sorted(list(labels.numpy())) == [1, 2] + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_from_filepaths_list_image_paths(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "e").mkdir() + _rand_image().save(tmpdir / "e_1.png") + + train_images = [ + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + ] + + spectrograms_data = AudioClassificationData.from_files( + train_files=train_images, + train_targets=[0, 3, 6], + val_files=train_images, + val_targets=[1, 4, 7], + test_files=train_images, + test_targets=[2, 5, 8], + batch_size=2, + num_workers=0, + ) + + # check training data + data = next(iter(spectrograms_data.train_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2,) + assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here + assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here + + # check validation data + data = next(iter(spectrograms_data.val_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2,) + assert list(labels.numpy()) == [1, 4] + + # check test data + data = next(iter(spectrograms_data.test_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2,) + assert list(labels.numpy()) == [2, 5] + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") +def test_from_filepaths_visualise(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "e").mkdir() + _rand_image().save(tmpdir / "e_1.png") + + train_images = [ + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + ] + + dm = AudioClassificationData.from_files( + train_files=train_images, + train_targets=[0, 3, 6], + val_files=train_images, + val_targets=[1, 4, 7], + test_files=train_images, + test_targets=[2, 5, 8], + batch_size=2, + num_workers=0, + ) + + # disable visualisation for testing + assert dm.data_fetcher.block_viz_window is True + dm.set_block_viz_window(False) + assert dm.data_fetcher.block_viz_window is False + + # call show functions + # dm.show_train_batch() + dm.show_train_batch("pre_tensor_transform") + dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") +def test_from_filepaths_visualise_multilabel(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + (tmpdir / "b").mkdir() + + image_a = str(tmpdir / "a" / "a_1.png") + image_b = str(tmpdir / "b" / "b_1.png") + + _rand_image().save(image_a) + _rand_image().save(image_b) + + dm = AudioClassificationData.from_files( + train_files=[image_a, image_b], + train_targets=[[0, 1, 0], [0, 1, 1]], + val_files=[image_b, image_a], + val_targets=[[1, 1, 0], [0, 0, 1]], + test_files=[image_b, image_b], + test_targets=[[0, 0, 1], [1, 1, 0]], + batch_size=2, + spectrogram_size=(64, 64), + ) + # disable visualisation for testing + assert dm.data_fetcher.block_viz_window is True + dm.set_block_viz_window(False) + assert dm.data_fetcher.block_viz_window is False + + # call show functions + dm.show_train_batch() + dm.show_train_batch("pre_tensor_transform") + dm.show_train_batch("to_tensor_transform") + dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) + dm.show_val_batch("per_batch_transform") + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_from_filepaths_splits(tmpdir): + tmpdir = Path(tmpdir) + + B, _, H, W = 2, 3, 224, 224 + img_size: Tuple[int, int] = (H, W) + + (tmpdir / "splits").mkdir() + _rand_image(img_size).save(tmpdir / "s.png") + + num_samples: int = 10 + val_split: float = 0.3 + + train_filepaths: List[str] = [str(tmpdir / "s.png") for _ in range(num_samples)] + + train_labels: List[int] = list(range(num_samples)) + + assert len(train_filepaths) == len(train_labels) + + _to_tensor = { + "to_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ), + } + + def run(transform: Any = None): + dm = AudioClassificationData.from_files( + train_files=train_filepaths, + train_targets=train_labels, + train_transform=transform, + val_transform=transform, + batch_size=B, + num_workers=0, + val_split=val_split, + spectrogram_size=img_size, + ) + data = next(iter(dm.train_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (B, 3, H, W) + assert labels.shape == (B,) + + run(_to_tensor) + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_from_folders_only_train(tmpdir): + train_dir = Path(tmpdir / "train") + train_dir.mkdir() + + (train_dir / "a").mkdir() + _rand_image().save(train_dir / "a" / "1.png") + _rand_image().save(train_dir / "a" / "2.png") + + (train_dir / "b").mkdir() + _rand_image().save(train_dir / "b" / "1.png") + _rand_image().save(train_dir / "b" / "2.png") + + spectrograms_data = AudioClassificationData.from_folders(train_dir, train_transform=None, batch_size=1) + + data = next(iter(spectrograms_data.train_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (1, 3, 128, 128) + assert labels.shape == (1,) + + assert spectrograms_data.val_dataloader() is None + assert spectrograms_data.test_dataloader() is None + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_from_folders_train_val(tmpdir): + + train_dir = Path(tmpdir / "train") + train_dir.mkdir() + + (train_dir / "a").mkdir() + _rand_image().save(train_dir / "a" / "1.png") + _rand_image().save(train_dir / "a" / "2.png") + + (train_dir / "b").mkdir() + _rand_image().save(train_dir / "b" / "1.png") + _rand_image().save(train_dir / "b" / "2.png") + spectrograms_data = AudioClassificationData.from_folders( + train_dir, + val_folder=train_dir, + test_folder=train_dir, + batch_size=2, + num_workers=0, + ) + + data = next(iter(spectrograms_data.train_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2,) + + data = next(iter(spectrograms_data.val_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2,) + assert list(labels.numpy()) == [0, 0] + + data = next(iter(spectrograms_data.test_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2,) + assert list(labels.numpy()) == [0, 0] + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_from_filepaths_multilabel(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + _rand_image().save(tmpdir / "a1.png") + _rand_image().save(tmpdir / "a2.png") + + train_images = [str(tmpdir / "a1.png"), str(tmpdir / "a2.png")] + train_labels = [[1, 0, 1, 0], [0, 0, 1, 1]] + valid_labels = [[1, 1, 1, 0], [1, 0, 0, 1]] + test_labels = [[1, 0, 1, 0], [1, 1, 0, 1]] + + dm = AudioClassificationData.from_files( + train_files=train_images, + train_targets=train_labels, + val_files=train_images, + val_targets=valid_labels, + test_files=train_images, + test_targets=test_labels, + batch_size=2, + num_workers=0, + ) + + data = next(iter(dm.train_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 4) + + data = next(iter(dm.val_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 4) + torch.testing.assert_allclose(labels, torch.tensor(valid_labels)) + + data = next(iter(dm.test_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 4) + torch.testing.assert_allclose(labels, torch.tensor(test_labels)) diff --git a/tests/audio/classification/test_model.py b/tests/audio/classification/test_model.py new file mode 100644 index 0000000000..0e5a4fa3fc --- /dev/null +++ b/tests/audio/classification/test_model.py @@ -0,0 +1,31 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock + +import pytest + +from flash.__main__ import main +from flash.core.utilities.imports import _IMAGE_AVAILABLE +from tests.helpers.utils import _AUDIO_TESTING + + +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_cli(): + cli_args = ["flash", "audio_classification", "--trainer.fast_dev_run", "True"] + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass diff --git a/tests/audio/speech_recognition/__init__.py b/tests/audio/speech_recognition/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/audio/speech_recognition/test_data.py b/tests/audio/speech_recognition/test_data.py new file mode 100644 index 0000000000..6205da309d --- /dev/null +++ b/tests/audio/speech_recognition/test_data.py @@ -0,0 +1,89 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +from pathlib import Path + +import pytest + +import flash +from flash.audio import SpeechRecognitionData +from flash.core.data.data_source import DefaultDataKeys +from tests.helpers.utils import _AUDIO_TESTING + +path = str(Path(flash.ASSETS_ROOT) / "example.wav") +sample = {"file": path, "text": "example input."} + +TEST_CSV_DATA = f"""file,text +{path},example input. +{path},example input. +{path},example input. +{path},example input. +{path},example input. +""" + + +def csv_data(tmpdir): + path = Path(tmpdir) / "data.csv" + path.write_text(TEST_CSV_DATA) + return path + + +def json_data(tmpdir, n_samples=5): + path = Path(tmpdir) / "data.json" + with path.open("w") as f: + f.write("\n".join([json.dumps(sample) for x in range(n_samples)])) + return path + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.") +def test_from_csv(tmpdir): + csv_path = csv_data(tmpdir) + dm = SpeechRecognitionData.from_csv("file", "text", train_file=csv_path, batch_size=1, num_workers=0) + batch = next(iter(dm.train_dataloader())) + assert DefaultDataKeys.INPUT in batch + assert DefaultDataKeys.TARGET in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.") +def test_stage_test_and_valid(tmpdir): + csv_path = csv_data(tmpdir) + dm = SpeechRecognitionData.from_csv( + "file", "text", train_file=csv_path, val_file=csv_path, test_file=csv_path, batch_size=1, num_workers=0 + ) + batch = next(iter(dm.val_dataloader())) + assert DefaultDataKeys.INPUT in batch + assert DefaultDataKeys.TARGET in batch + + batch = next(iter(dm.test_dataloader())) + assert DefaultDataKeys.INPUT in batch + assert DefaultDataKeys.TARGET in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.") +def test_from_json(tmpdir): + json_path = json_data(tmpdir) + dm = SpeechRecognitionData.from_json("file", "text", train_file=json_path, batch_size=1, num_workers=0) + batch = next(iter(dm.train_dataloader())) + assert DefaultDataKeys.INPUT in batch + assert DefaultDataKeys.TARGET in batch + + +@pytest.mark.skipif(_AUDIO_TESTING, reason="audio libraries are installed.") +def test_audio_module_not_found_error(): + with pytest.raises(ModuleNotFoundError, match="[audio]"): + SpeechRecognitionData.from_json("file", "text", train_file="", batch_size=1, num_workers=0) diff --git a/tests/audio/speech_recognition/test_data_model_integration.py b/tests/audio/speech_recognition/test_data_model_integration.py new file mode 100644 index 0000000000..eda3ac86b3 --- /dev/null +++ b/tests/audio/speech_recognition/test_data_model_integration.py @@ -0,0 +1,83 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +from pathlib import Path + +import pytest +from pytorch_lightning import Trainer + +import flash +from flash.audio import SpeechRecognition, SpeechRecognitionData +from tests.helpers.utils import _AUDIO_TESTING + +TEST_BACKBONE = "patrickvonplaten/wav2vec2_tiny_random_robust" # super small model for testing + +path = str(Path(flash.ASSETS_ROOT) / "example.wav") +sample = {"file": path, "text": "example input."} + +TEST_CSV_DATA = f"""file,text +{path},example input. +{path},example input. +{path},example input. +{path},example input. +{path},example input. +""" + + +def csv_data(tmpdir): + path = Path(tmpdir) / "data.csv" + path.write_text(TEST_CSV_DATA) + return path + + +def json_data(tmpdir, n_samples=5): + path = Path(tmpdir) / "data.json" + with path.open("w") as f: + f.write("\n".join([json.dumps(sample) for x in range(n_samples)])) + return path + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_classification_csv(tmpdir): + csv_path = csv_data(tmpdir) + + data = SpeechRecognitionData.from_csv( + "file", + "text", + train_file=csv_path, + num_workers=0, + batch_size=2, + ) + model = SpeechRecognition(backbone=TEST_BACKBONE) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, datamodule=data) + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_classification_json(tmpdir): + json_path = json_data(tmpdir) + + data = SpeechRecognitionData.from_json( + "file", + "text", + train_file=json_path, + num_workers=0, + batch_size=2, + ) + model = SpeechRecognition(backbone=TEST_BACKBONE) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, datamodule=data) diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py new file mode 100644 index 0000000000..5ce932cd4d --- /dev/null +++ b/tests/audio/speech_recognition/test_model.py @@ -0,0 +1,102 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import re +from unittest import mock + +import numpy as np +import pytest +import torch + +from flash import Trainer +from flash.__main__ import main +from flash.audio import SpeechRecognition +from flash.audio.speech_recognition.data import SpeechRecognitionPostprocess, SpeechRecognitionPreprocess +from flash.core.data.data_source import DefaultDataKeys +from tests.helpers.utils import _AUDIO_TESTING, _SERVE_TESTING + +# ======== Mock functions ======== + + +class DummyDataset(torch.utils.data.Dataset): + def __getitem__(self, index): + return { + DefaultDataKeys.INPUT: np.random.randn(86631), + DefaultDataKeys.TARGET: "some target text", + DefaultDataKeys.METADATA: {"sampling_rate": 16000}, + } + + def __len__(self) -> int: + return 100 + + +# ============================== + +TEST_BACKBONE = "patrickvonplaten/wav2vec2_tiny_random_robust" # super small model for testing + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_init_train(tmpdir): + model = SpeechRecognition(backbone=TEST_BACKBONE) + train_dl = torch.utils.data.DataLoader(DummyDataset()) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, train_dl) + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_jit(tmpdir): + sample_input = {"input_values": torch.randn(size=torch.Size([1, 86631])).float()} + path = os.path.join(tmpdir, "test.pt") + + model = SpeechRecognition(backbone=TEST_BACKBONE) + model.eval() + + # Huggingface model only supports `torch.jit.trace` with `strict=False` + model = torch.jit.trace(model, sample_input, strict=False) + + torch.jit.save(model, path) + model = torch.jit.load(path) + + out = model(sample_input)["logits"] + assert isinstance(out, torch.Tensor) + assert out.shape == torch.Size([1, 95, 12]) + + +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@mock.patch("flash._IS_TESTING", True) +def test_serve(): + model = SpeechRecognition(backbone=TEST_BACKBONE) + # TODO: Currently only servable once a preprocess and postprocess have been attached + model._preprocess = SpeechRecognitionPreprocess() + model._postprocess = SpeechRecognitionPostprocess() + model.eval() + model.serve() + + +@pytest.mark.skipif(_AUDIO_TESTING, reason="audio libraries are installed.") +def test_load_from_checkpoint_dependency_error(): + with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[audio]'")): + SpeechRecognition.load_from_checkpoint("not_a_real_checkpoint.pt") + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_cli(): + cli_args = ["flash", "speech_recognition", "--trainer.fast_dev_run", "True"] + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass diff --git a/tests/conftest.py b/tests/conftest.py index f2f67cc829..43fd8dc824 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ class UUID_String(str): - """Class to replace UUID object with str instance and hex attribute""" + """Class to replace UUID object with str instance and hex attribute.""" @property def hex(self): @@ -80,7 +80,7 @@ def lightning_squeezenet1_1_obj(): def squeezenet_servable(squeezenet1_1_model, session_global_datadir): from flash.core.serve import Servable - trace = torch.jit.trace(squeezenet1_1_model.eval(), (torch.rand(1, 3, 224, 224), )) + trace = torch.jit.trace(squeezenet1_1_model.eval(), (torch.rand(1, 3, 224, 224),)) fpth = str(session_global_datadir / "squeezenet_jit_trace.pt") torch.jit.save(trace, fpth) diff --git a/tests/core/data/test_auto_dataset.py b/tests/core/data/test_auto_dataset.py index 7acbffe671..8571363a0a 100644 --- a/tests/core/data/test_auto_dataset.py +++ b/tests/core/data/test_auto_dataset.py @@ -22,7 +22,6 @@ class _AutoDatasetTestDataSource(DataSource): - def __init__(self, with_dset: bool): self._callbacks: List[FlashCallback] = [] self.load_data_count = 0 diff --git a/tests/core/data/test_base_viz.py b/tests/core/data/test_base_viz.py index 20d2084b9b..9af754eb1c 100644 --- a/tests/core/data/test_base_viz.py +++ b/tests/core/data/test_base_viz.py @@ -37,7 +37,6 @@ def _rand_image(): class CustomBaseVisualization(BaseVisualization): - def __init__(self): super().__init__() @@ -77,7 +76,6 @@ def check_reset(self): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") class TestBaseViz: - def test_base_viz(self, tmpdir): seed_everything(42) @@ -89,7 +87,6 @@ def test_base_viz(self, tmpdir): _rand_image().save(train_images[1]) class CustomImageClassificationData(ImageClassificationData): - @staticmethod def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: return CustomBaseVisualization(*args, **kwargs) @@ -154,7 +151,7 @@ def _get_result(function_name: str): if not is_predict: res = _get_result("per_batch_transform") - assert res[0][DefaultDataKeys.TARGET].shape == (B, ) + assert res[0][DefaultDataKeys.TARGET].shape == (B,) assert dm.data_fetcher.show_load_sample_called assert dm.data_fetcher.show_pre_tensor_transform_called @@ -165,12 +162,13 @@ def _get_result(function_name: str): dm.data_fetcher.check_reset() @pytest.mark.parametrize( - "func_names, valid", [ + "func_names, valid", + [ (["load_sample"], True), (["not_a_hook"], False), (["load_sample", "pre_tensor_transform"], True), (["load_sample", "not_a_hook"], True), - ] + ], ) def test_show(self, func_names, valid): base_viz = CustomBaseVisualization() diff --git a/tests/core/data/test_batch.py b/tests/core/data/test_batch.py index caba5cf4a0..a03457ed77 100644 --- a/tests/core/data/test_batch.py +++ b/tests/core/data/test_batch.py @@ -102,9 +102,9 @@ def test_tensor_batch(): def test_sequence(self): batch = { - 'a': torch.rand(self.BATCH_SIZE, 4), - 'b': torch.rand(self.BATCH_SIZE, 2), - 'c': torch.rand(self.BATCH_SIZE) + "a": torch.rand(self.BATCH_SIZE, 4), + "b": torch.rand(self.BATCH_SIZE, 2), + "c": torch.rand(self.BATCH_SIZE), } output = default_uncollate(batch) @@ -112,13 +112,13 @@ def test_sequence(self): assert len(batch) == self.BATCH_SIZE for sample in output: - assert list(sample.keys()) == ['a', 'b', 'c'] - assert isinstance(sample['a'], list) - assert len(sample['a']) == 4 - assert isinstance(sample['b'], list) - assert len(sample['b']) == 2 - assert isinstance(sample['c'], torch.Tensor) - assert len(sample['c'].shape) == 0 + assert list(sample.keys()) == ["a", "b", "c"] + assert isinstance(sample["a"], list) + assert len(sample["a"]) == 4 + assert isinstance(sample["b"], list) + assert len(sample["b"]) == 2 + assert isinstance(sample["c"], torch.Tensor) + assert len(sample["c"].shape) == 0 def test_named_tuple(self): Batch = namedtuple("Batch", ["x", "y"]) diff --git a/tests/core/data/test_callback.py b/tests/core/data/test_callback.py index e11591f33a..5db55dee08 100644 --- a/tests/core/data/test_callback.py +++ b/tests/core/data/test_callback.py @@ -23,8 +23,9 @@ from flash.core.trainer import Trainer +@mock.patch("pickle.dumps") # need to mock pickle or we get pickle error @mock.patch("torch.save") # need to mock torch.save or we get pickle error -def test_flash_callback(_, tmpdir): +def test_flash_callback(_, __, tmpdir): """Test the callback hook system for fit.""" callback_mock = MagicMock() @@ -47,7 +48,6 @@ def test_flash_callback(_, tmpdir): ] class CustomModel(Task): - def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) @@ -91,5 +91,5 @@ def __init__(self): call.on_post_tensor_transform(ANY, RunningStage.VALIDATING), call.on_collate(ANY, RunningStage.VALIDATING), call.on_per_batch_transform(ANY, RunningStage.VALIDATING), - call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING) + call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING), ] diff --git a/tests/core/data/test_callbacks.py b/tests/core/data/test_callbacks.py index 284de09b02..07e89fec16 100644 --- a/tests/core/data/test_callbacks.py +++ b/tests/core/data/test_callbacks.py @@ -23,9 +23,7 @@ def test_base_data_fetcher(tmpdir): - class CheckData(BaseDataFetcher): - def check(self): assert self.batches["val"]["load_sample"] == [0, 1, 2, 3, 4] assert self.batches["val"]["pre_tensor_transform"] == [0, 1, 2, 3, 4] @@ -38,7 +36,6 @@ def check(self): assert self.batches["predict"] == {} class CustomDataModule(DataModule): - @staticmethod def configure_data_fetcher(): return CheckData() @@ -70,13 +67,11 @@ def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_dat data_fetcher.check() data_fetcher.reset() - assert data_fetcher.batches == {'train': {}, 'test': {}, 'val': {}, 'predict': {}} + assert data_fetcher.batches == {"train": {}, "test": {}, "val": {}, "predict": {}} def test_data_loaders_num_workers_to_0(tmpdir): - """ - num_workers should be set to `0` internally for visualization and not for training. - """ + """num_workers should be set to `0` internally for visualization and not for training.""" datamodule = DataModule(train_dataset=range(10), num_workers=3) iterator = datamodule._reset_iterator(RunningStage.TRAINING) diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py index 2b593cdd9e..7124675f30 100644 --- a/tests/core/data/test_data_pipeline.py +++ b/tests/core/data/test_data_pipeline.py @@ -44,7 +44,6 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: return torch.rand(1), torch.rand(1) @@ -53,7 +52,6 @@ def __len__(self) -> int: class TestDataPipelineState: - @staticmethod def test_str(): state = DataPipelineState() @@ -95,9 +93,7 @@ def test_data_pipeline_str(): @pytest.mark.parametrize("use_preprocess", [False, True]) @pytest.mark.parametrize("use_postprocess", [False, True]) def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir): - class CustomModel(Task): - def __init__(self, postprocess: Optional[Postprocess] = None): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) self._postprocess = postprocess @@ -135,9 +131,7 @@ class SubPostprocess(Postprocess): def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir): - class CustomPreprocess(DefaultPreprocess): - def val_pre_tensor_transform(self, *_, **__): pass @@ -258,7 +252,6 @@ def test_per_batch_transform_on_device(self, *_, **__): class CustomPreprocess(DefaultPreprocess): - def train_per_sample_transform(self, *_, **__): pass @@ -307,9 +300,7 @@ def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor(): def test_detach_preprocessing_from_model(tmpdir): - class CustomModel(Task): - def __init__(self, postprocess: Optional[Postprocess] = None): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) self._postprocess = postprocess @@ -333,7 +324,6 @@ def train_dataloader(self) -> Any: class TestPreprocess(DefaultPreprocess): - def train_per_sample_transform(self, *_, **__): pass @@ -363,7 +353,6 @@ def predict_per_batch_transform_on_device(self, *_, **__): def test_attaching_datapipeline_to_model(tmpdir): - class SubPreprocess(DefaultPreprocess): pass @@ -371,7 +360,6 @@ class SubPreprocess(DefaultPreprocess): data_pipeline = DataPipeline(preprocess=preprocess) class CustomModel(Task): - def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) self._postprocess = Postprocess() @@ -513,8 +501,7 @@ def test_stage_orchestrator_state_attach_detach(tmpdir): _original_predict_step = model.predict_step class CustomDataPipeline(DataPipeline): - - def _attach_postprocess_to_model(self, model: 'Task', _postprocesssor: _Postprocessor) -> 'Task': + def _attach_postprocess_to_model(self, model: "Task", _postprocesssor: _Postprocessor) -> "Task": model.predict_step = self._model_predict_step_wrapper(model.predict_step, _postprocesssor, model) return model @@ -528,7 +515,6 @@ def _attach_postprocess_to_model(self, model: 'Task', _postprocesssor: _Postproc class LamdaDummyDataset(torch.utils.data.Dataset): - def __init__(self, fx: Callable): self.fx = fx @@ -540,7 +526,6 @@ def __len__(self) -> int: class TestPreprocessTransformationsDataSource(DataSource): - def __init__(self): super().__init__() @@ -589,7 +574,7 @@ def test_load_data(self, sample) -> LamdaDummyDataset: @staticmethod def fn_predict_load_data() -> List[str]: - return (["a", "b"]) + return ["a", "b"] def predict_load_data(self, sample) -> LamdaDummyDataset: assert self.predicting @@ -599,7 +584,6 @@ def predict_load_data(self, sample) -> LamdaDummyDataset: class TestPreprocessTransformations(DefaultPreprocess): - def __init__(self): super().__init__(data_sources={"default": TestPreprocessTransformationsDataSource()}) @@ -616,7 +600,7 @@ def train_pre_tensor_transform(self, sample: Any) -> Any: assert self.training assert self.current_fn == "pre_tensor_transform" self.train_pre_tensor_transform_called = True - return sample + (5, ) + return sample + (5,) def train_collate(self, samples) -> Tensor: assert self.training @@ -640,9 +624,9 @@ def val_collate(self, samples) -> Dict[str, Tensor]: assert self.validating assert self.current_fn == "collate" self.val_collate_called = True - _count = samples[0]['a'] - assert samples == [{'a': _count, 'b': _count + 1}, {'a': _count + 1, 'b': _count + 2}] - return {'a': tensor([0, 1]), 'b': tensor([1, 2])} + _count = samples[0]["a"] + assert samples == [{"a": _count, "b": _count + 1}, {"a": _count + 1, "b": _count + 2}] + return {"a": tensor([0, 1]), "b": tensor([1, 2])} def val_per_batch_transform_on_device(self, batch: Any) -> Any: assert self.validating @@ -668,14 +652,12 @@ def test_post_tensor_transform(self, sample: Tensor) -> Tensor: class TestPreprocessTransformations2(TestPreprocessTransformations): - def val_to_tensor_transform(self, sample: Any) -> Tensor: self.val_to_tensor_transform_called = True return {"a": tensor(sample["a"]), "b": tensor(sample["b"])} class CustomModel(Task): - def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) @@ -691,11 +673,11 @@ def test_step(self, batch, batch_idx): assert len(batch) == 2 assert batch[0].shape == torch.Size([2, 1]) - def predict_step(self, batch, batch_idx, dataloader_idx): - assert batch[0][0] == 'a' - assert batch[0][1] == 'a' - assert batch[1][0] == 'b' - assert batch[1][1] == 'b' + def predict_step(self, batch, batch_idx, dataloader_idx=None): + assert batch[0][0] == "a" + assert batch[0][1] == "a" + assert batch[1][0] == "b" + assert batch[1][1] == "b" return tensor([0, 0, 0]) @@ -709,8 +691,8 @@ def test_datapipeline_transformations(tmpdir): batch = next(iter(datamodule.train_dataloader())) assert torch.equal(batch, tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]])) - assert datamodule.val_dataloader().dataset[0] == {'a': 0, 'b': 1} - assert datamodule.val_dataloader().dataset[1] == {'a': 1, 'b': 2} + assert datamodule.val_dataloader().dataset[0] == {"a": 0, "b": 1} + assert datamodule.val_dataloader().dataset[1] == {"a": 1, "b": 2} with pytest.raises(MisconfigurationException, match="When ``to_tensor_transform``"): batch = next(iter(datamodule.val_dataloader())) @@ -728,7 +710,7 @@ def test_datapipeline_transformations(tmpdir): limit_val_batches=1, limit_test_batches=2, limit_predict_batches=2, - num_sanity_val_steps=1 + num_sanity_val_steps=1, ) trainer.fit(model, datamodule=datamodule) trainer.test(model) @@ -752,9 +734,7 @@ def test_datapipeline_transformations(tmpdir): def test_is_overriden_recursive(tmpdir): - class TestPreprocess(DefaultPreprocess): - def collate(self, *_): pass @@ -775,9 +755,7 @@ def val_collate(self, *_): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @patch("torch.save") # need to mock torch.save or we get pickle error def test_dummy_example(tmpdir): - class ImageDataSource(DataSource): - def load_data(self, folder: str): # from folder -> return files paths return ["a.jpg", "b.jpg"] @@ -788,7 +766,6 @@ def load_sample(self, path: str) -> Image.Image: return Image.fromarray(img8Bit) class ImageClassificationPreprocess(DefaultPreprocess): - def __init__( self, train_transform=None, @@ -817,7 +794,6 @@ def train_per_sample_transform_on_device(self, sample: Any) -> Any: return self._train_per_sample_transform_on_device(sample) class CustomModel(Task): - def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) @@ -856,17 +832,15 @@ class CustomDataModule(DataModule): limit_val_batches=1, limit_test_batches=2, limit_predict_batches=2, - num_sanity_val_steps=1 + num_sanity_val_steps=1, ) trainer.fit(model, datamodule=datamodule) trainer.test(model) def test_preprocess_transforms(tmpdir): - """ - This test makes sure that when a preprocess is being provided transforms as dictionaries, - checking is done properly, and collate_in_worker_from_transform is properly extracted. - """ + """This test makes sure that when a preprocess is being provided transforms as dictionaries, checking is done + properly, and collate_in_worker_from_transform is properly extracted.""" with pytest.raises(MisconfigurationException, match="Transform should be a dict."): DefaultPreprocess(train_transform="choco") @@ -885,13 +859,13 @@ def test_preprocess_transforms(tmpdir): preprocess = DefaultPreprocess( train_transform={ "per_batch_transform": torch.nn.Linear(1, 1), - "per_sample_transform_on_device": torch.nn.Linear(1, 1) + "per_sample_transform_on_device": torch.nn.Linear(1, 1), } ) preprocess = DefaultPreprocess( train_transform={"per_batch_transform": torch.nn.Linear(1, 1)}, - predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)} + predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)}, ) # keep is None assert preprocess._train_collate_in_worker_from_transform is True @@ -910,7 +884,6 @@ def test_preprocess_transforms(tmpdir): assert predict_preprocessor.collate_fn.func == DataPipeline._identity class CustomPreprocess(DefaultPreprocess): - def per_sample_transform_on_device(self, sample: Any) -> Any: return super().per_sample_transform_on_device(sample) @@ -919,7 +892,7 @@ def per_batch_transform(self, batch: Any) -> Any: preprocess = CustomPreprocess( train_transform={"per_batch_transform": torch.nn.Linear(1, 1)}, - predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)} + predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)}, ) # keep is None assert preprocess._train_collate_in_worker_from_transform is True @@ -941,9 +914,7 @@ def per_batch_transform(self, batch: Any) -> Any: def test_iterable_auto_dataset(tmpdir): - class CustomDataSource(DataSource): - def load_sample(self, index: int) -> Dict[str, int]: return {"index": index} @@ -954,7 +925,6 @@ def load_sample(self, index: int) -> Dict[str, int]: class CustomPreprocessHyperparameters(DefaultPreprocess): - def __init__(self, token: str, *args, **kwargs): self.token = token super().__init__(*args, **kwargs) diff --git a/tests/core/data/test_data_source.py b/tests/core/data/test_data_source.py index 77dbb173be..24a0b875fc 100644 --- a/tests/core/data/test_data_source.py +++ b/tests/core/data/test_data_source.py @@ -17,7 +17,7 @@ def test_dataset_data_source(): data_source = DatasetDataSource() - input, target = 'test', 3 + input, target = "test", 3 assert data_source.load_sample((input, target)) == {DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} assert data_source.load_sample(input) == {DefaultDataKeys.INPUT: input} diff --git a/tests/core/data/test_process.py b/tests/core/data/test_process.py index 2e834fd666..509bbce3f8 100644 --- a/tests/core/data/test_process.py +++ b/tests/core/data/test_process.py @@ -33,41 +33,43 @@ def test_serializer(): my_serializer = Serializer() - assert my_serializer.serialize('test') == 'test' + assert my_serializer.serialize("test") == "test" my_serializer.serialize = Mock() my_serializer.disable() - assert my_serializer('test') == 'test' + assert my_serializer("test") == "test" my_serializer.serialize.assert_not_called() my_serializer.enable() - my_serializer('test') + my_serializer("test") my_serializer.serialize.assert_called_once() def test_serializer_mapping(): - """Tests that ``SerializerMapping`` correctly passes its inputs to the underlying serializers. Also checks that - state is retrieved / loaded correctly.""" + """Tests that ``SerializerMapping`` correctly passes its inputs to the underlying serializers. + + Also checks that state is retrieved / loaded correctly. + """ serializer1 = Serializer() - serializer1.serialize = Mock(return_value='test1') + serializer1.serialize = Mock(return_value="test1") class Serializer1State(ProcessState): pass serializer2 = Serializer() - serializer2.serialize = Mock(return_value='test2') + serializer2.serialize = Mock(return_value="test2") class Serializer2State(ProcessState): pass - serializer_mapping = SerializerMapping({'key1': serializer1, 'key2': serializer2}) - assert serializer_mapping({'key1': 'serializer1', 'key2': 'serializer2'}) == {'key1': 'test1', 'key2': 'test2'} - serializer1.serialize.assert_called_once_with('serializer1') - serializer2.serialize.assert_called_once_with('serializer2') + serializer_mapping = SerializerMapping({"key1": serializer1, "key2": serializer2}) + assert serializer_mapping({"key1": "serializer1", "key2": "serializer2"}) == {"key1": "test1", "key2": "test2"} + serializer1.serialize.assert_called_once_with("serializer1") + serializer2.serialize.assert_called_once_with("serializer2") - with pytest.raises(ValueError, match='output must be a mapping'): - serializer_mapping('not a mapping') + with pytest.raises(ValueError, match="output must be a mapping"): + serializer_mapping("not a mapping") serializer1_state = Serializer1State() serializer2_state = Serializer2State() @@ -87,10 +89,9 @@ class Serializer2State(ProcessState): def test_saving_with_serializers(tmpdir): - checkpoint_file = os.path.join(tmpdir, 'tmp.ckpt') + checkpoint_file = os.path.join(tmpdir, "tmp.ckpt") class CustomModel(Task): - def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) @@ -110,7 +111,6 @@ def __init__(self): class CustomPreprocess(DefaultPreprocess): - def __init__(self): super().__init__( data_sources={ diff --git a/tests/core/data/test_sampler.py b/tests/core/data/test_sampler.py index 9ee9ace3a1..fd114d64f2 100644 --- a/tests/core/data/test_sampler.py +++ b/tests/core/data/test_sampler.py @@ -19,14 +19,14 @@ @mock.patch("flash.core.data.data_module.DataLoader") def test_dataloaders_with_sampler(mock_dataloader): - train_ds = val_ds = test_ds = 'dataset' - mock_sampler = 'sampler' + train_ds = val_ds = test_ds = "dataset" + mock_sampler = mock.MagicMock() dm = DataModule(train_ds, val_ds, test_ds, num_workers=0, sampler=mock_sampler) assert dm.sampler is mock_sampler dl = dm.train_dataloader() kwargs = mock_dataloader.call_args[1] - assert 'sampler' in kwargs - assert kwargs['sampler'] is mock_sampler + assert "sampler" in kwargs + assert kwargs["sampler"] is mock_sampler.return_value for dl in [dm.val_dataloader(), dm.test_dataloader()]: kwargs = mock_dataloader.call_args[1] - assert 'sampler' not in kwargs + assert "sampler" not in kwargs diff --git a/tests/core/data/test_serialization.py b/tests/core/data/test_serialization.py index 5c368bb0b9..948f6bee13 100644 --- a/tests/core/data/test_serialization.py +++ b/tests/core/data/test_serialization.py @@ -25,13 +25,11 @@ class CustomModel(Task): - def __init__(self): super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) class CustomPreprocess(DefaultPreprocess): - @classmethod def load_data(cls, data): return data @@ -40,8 +38,8 @@ def load_data(cls, data): def test_serialization_data_pipeline(tmpdir): model = CustomModel() - checkpoint_file = os.path.join(tmpdir, 'tmp.ckpt') - checkpoint = ModelCheckpoint(tmpdir, 'test.ckpt') + checkpoint_file = os.path.join(tmpdir, "tmp.ckpt") + checkpoint = ModelCheckpoint(tmpdir, "test.ckpt") trainer = Trainer(callbacks=[checkpoint], max_epochs=1) dummy_data = DataLoader(list(zip(torch.arange(10, dtype=torch.float), torch.arange(10, dtype=torch.float)))) trainer.fit(model, dummy_data) @@ -69,5 +67,5 @@ def fn(*args, **kwargs): assert loaded_model.data_pipeline assert isinstance(loaded_model.preprocess, CustomPreprocess) for file in os.listdir(tmpdir): - if file.endswith('.ckpt'): + if file.endswith(".ckpt"): os.remove(os.path.join(tmpdir, file)) diff --git a/tests/core/data/test_splits.py b/tests/core/data/test_splits.py index 14e7f12993..0d58ed2228 100644 --- a/tests/core/data/test_splits.py +++ b/tests/core/data/test_splits.py @@ -28,7 +28,6 @@ def test_split_dataset(): assert len(np.unique(train_ds.indices)) == len(train_ds.indices) class Dataset: - def __init__(self): self.data = [0, 1, 2] self.name = "something" diff --git a/tests/core/data/test_transforms.py b/tests/core/data/test_transforms.py index f9239aa654..b66bd41cc8 100644 --- a/tests/core/data/test_transforms.py +++ b/tests/core/data/test_transforms.py @@ -23,40 +23,21 @@ class TestApplyToKeys: - @pytest.mark.parametrize( - "sample, keys, expected", [ - ({ - DefaultDataKeys.INPUT: "test" - }, DefaultDataKeys.INPUT, "test"), + "sample, keys, expected", + [ + ({DefaultDataKeys.INPUT: "test"}, DefaultDataKeys.INPUT, "test"), ( - { - DefaultDataKeys.INPUT: "test_a", - DefaultDataKeys.TARGET: "test_b" - }, + {DefaultDataKeys.INPUT: "test_a", DefaultDataKeys.TARGET: "test_b"}, [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], ["test_a", "test_b"], ), - ({ - "input": "test" - }, "input", "test"), - ({ - "input": "test_a", - "target": "test_b" - }, ["input", "target"], ["test_a", "test_b"]), - ({ - "input": "test_a", - "target": "test_b", - "extra": "..." - }, ["input", "target"], ["test_a", "test_b"]), - ({ - "input": "test_a", - "target": "test_b" - }, ["input", "target", "extra"], ["test_a", "test_b"]), - ({ - "target": "..." - }, "input", None), - ] + ({"input": "test"}, "input", "test"), + ({"input": "test_a", "target": "test_b"}, ["input", "target"], ["test_a", "test_b"]), + ({"input": "test_a", "target": "test_b", "extra": "..."}, ["input", "target"], ["test_a", "test_b"]), + ({"input": "test_a", "target": "test_b"}, ["input", "target", "extra"], ["test_a", "test_b"]), + ({"target": "..."}, "input", None), + ], ) def test_forward(self, sample, keys, expected): transform = Mock(return_value=["out"] * len(keys)) @@ -67,7 +48,8 @@ def test_forward(self, sample, keys, expected): transform.assert_not_called() @pytest.mark.parametrize( - "transform, expected", [ + "transform, expected", + [ ( ApplyToKeys(DefaultDataKeys.INPUT, torch.nn.ReLU()), "ApplyToKeys(keys=, transform=ReLU())", @@ -82,7 +64,7 @@ def test_forward(self, sample, keys, expected): ApplyToKeys(["input", "target"], torch.nn.ReLU()), "ApplyToKeys(keys=['input', 'target'], transform=ReLU())", ), - ] + ], ) def test_repr(self, transform, expected): assert repr(transform) == expected @@ -118,18 +100,9 @@ def test_kornia_parallel_transforms(with_params): def test_kornia_collate(): samples = [ - { - DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), - DefaultDataKeys.TARGET: 1 - }, - { - DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), - DefaultDataKeys.TARGET: 2 - }, - { - DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), - DefaultDataKeys.TARGET: 3 - }, + {DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), DefaultDataKeys.TARGET: 1}, + {DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), DefaultDataKeys.TARGET: 2}, + {DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), DefaultDataKeys.TARGET: 3}, ] result = kornia_collate(samples) @@ -145,24 +118,13 @@ def test_kornia_collate(): "base_transforms, additional_transforms, expected_result", [ ( - { - "to_tensor_transform": _MOCK_TRANSFORM - }, - { - "post_tensor_transform": _MOCK_TRANSFORM - }, - { - "to_tensor_transform": _MOCK_TRANSFORM, - "post_tensor_transform": _MOCK_TRANSFORM - }, + {"to_tensor_transform": _MOCK_TRANSFORM}, + {"post_tensor_transform": _MOCK_TRANSFORM}, + {"to_tensor_transform": _MOCK_TRANSFORM, "post_tensor_transform": _MOCK_TRANSFORM}, ), ( - { - "to_tensor_transform": _MOCK_TRANSFORM - }, - { - "to_tensor_transform": _MOCK_TRANSFORM - }, + {"to_tensor_transform": _MOCK_TRANSFORM}, + {"to_tensor_transform": _MOCK_TRANSFORM}, { "to_tensor_transform": nn.Sequential( convert_to_modules(_MOCK_TRANSFORM), convert_to_modules(_MOCK_TRANSFORM) @@ -170,33 +132,23 @@ def test_kornia_collate(): }, ), ( - { - "to_tensor_transform": _MOCK_TRANSFORM - }, - { - "to_tensor_transform": _MOCK_TRANSFORM, - "post_tensor_transform": _MOCK_TRANSFORM - }, + {"to_tensor_transform": _MOCK_TRANSFORM}, + {"to_tensor_transform": _MOCK_TRANSFORM, "post_tensor_transform": _MOCK_TRANSFORM}, { "to_tensor_transform": nn.Sequential( convert_to_modules(_MOCK_TRANSFORM), convert_to_modules(_MOCK_TRANSFORM) ), - "post_tensor_transform": _MOCK_TRANSFORM + "post_tensor_transform": _MOCK_TRANSFORM, }, ), ( - { - "to_tensor_transform": _MOCK_TRANSFORM, - "post_tensor_transform": _MOCK_TRANSFORM - }, - { - "to_tensor_transform": _MOCK_TRANSFORM - }, + {"to_tensor_transform": _MOCK_TRANSFORM, "post_tensor_transform": _MOCK_TRANSFORM}, + {"to_tensor_transform": _MOCK_TRANSFORM}, { "to_tensor_transform": nn.Sequential( convert_to_modules(_MOCK_TRANSFORM), convert_to_modules(_MOCK_TRANSFORM) ), - "post_tensor_transform": _MOCK_TRANSFORM + "post_tensor_transform": _MOCK_TRANSFORM, }, ), ], diff --git a/tests/core/optimizers/__init__.py b/tests/core/optimizers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/core/optimizers/test_lr_shceduler.py b/tests/core/optimizers/test_lr_shceduler.py new file mode 100644 index 0000000000..922978b014 --- /dev/null +++ b/tests/core/optimizers/test_lr_shceduler.py @@ -0,0 +1,64 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import pytest +from torch import nn +from torch.optim import Adam + +from flash.core.optimizers import LinearWarmupCosineAnnealingLR + + +@pytest.mark.parametrize( + "lr, warmup_epochs, max_epochs, warmup_start_lr, eta_min", + [ + (1, 10, 3200, 0.001, 0.0), + (1e-4, 40, 300, 1e-6, 1e-5), + (0.01, 1, 10, 0.0, 0.0), + (0.01, 0, 10, 0.0, 0.0), # only cosine decay + (0.01, 10, 10, 0.0, 0.0), # only linear warmup + ], +) +def test_linear_warmup_cosine_annealing_lr(tmpdir, lr, warmup_epochs, max_epochs, warmup_start_lr, eta_min): + layer1 = nn.Linear(10, 1) + layer2 = nn.Linear(10, 1) + optimizer1 = Adam(layer1.parameters(), lr=lr) + optimizer2 = Adam(layer2.parameters(), lr=lr) + + scheduler1 = LinearWarmupCosineAnnealingLR( + optimizer1, + warmup_epochs=warmup_epochs, + max_epochs=max_epochs, + warmup_start_lr=warmup_start_lr, + eta_min=eta_min, + ) + + scheduler2 = LinearWarmupCosineAnnealingLR( + optimizer2, + warmup_epochs=warmup_epochs, + max_epochs=max_epochs, + warmup_start_lr=warmup_start_lr, + eta_min=eta_min, + ) + + # compares closed form lr values against values of get_lr function + for epoch in range(max_epochs): + scheduler1.step(epoch) + expected_lr = scheduler1.get_last_lr()[0] + current_lr = scheduler2.get_last_lr()[0] + + assert math.isclose(expected_lr, current_lr, rel_tol=1e-12) + optimizer1.step() + optimizer2.step() + scheduler2.step() diff --git a/tests/core/optimizers/test_optimizers.py b/tests/core/optimizers/test_optimizers.py new file mode 100644 index 0000000000..1413b762bc --- /dev/null +++ b/tests/core/optimizers/test_optimizers.py @@ -0,0 +1,57 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +from torch import nn + +from flash.core.optimizers import LAMB, LARS, LinearWarmupCosineAnnealingLR + + +@pytest.mark.parametrize( + "optim_fn, lr, kwargs", + [ + (LARS, 0.1, {}), + (LARS, 0.1, {"weight_decay": 0.001}), + (LARS, 0.1, {"momentum": 0.9}), + (LAMB, 1e-3, {}), + (LAMB, 1e-3, {"amsgrad": True}), + (LAMB, 1e-3, {"weight_decay": 0.001}), + ], +) +def test_optim_call(tmpdir, optim_fn, lr, kwargs): + layer = nn.Linear(10, 1) + optimizer = optim_fn(layer.parameters(), lr=lr, **kwargs) + + for _ in range(10): + dummy_input = torch.rand(1, 10) + dummy_input.requires_grad = True + result = layer(dummy_input) + result.backward() + optimizer.step() + + +@pytest.mark.parametrize("optim_fn, lr", [(LARS, 0.1), (LAMB, 1e-3)]) +def test_optim_with_scheduler(tmpdir, optim_fn, lr): + max_epochs = 10 + layer = nn.Linear(10, 1) + optimizer = optim_fn(layer.parameters(), lr=lr) + scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=2, max_epochs=max_epochs) + + for _ in range(max_epochs): + dummy_input = torch.rand(1, 10) + dummy_input.requires_grad = True + result = layer(dummy_input) + result.backward() + optimizer.step() + scheduler.step() diff --git a/tests/core/serve/models.py b/tests/core/serve/models.py index 63f99327f7..9e0e914c41 100644 --- a/tests/core/serve/models.py +++ b/tests/core/serve/models.py @@ -14,7 +14,6 @@ class LightningSqueezenet(pl.LightningModule): - def __init__(self): super().__init__() self.model = squeezenet1_1(pretrained=True).eval() @@ -24,7 +23,6 @@ def forward(self, x): class LightningSqueezenetServable(pl.LightningModule): - def __init__(self, model): super().__init__() self.model = model @@ -38,7 +36,6 @@ def _func_from_exposed(arg): class ClassificationInference(ModelComponent): - def __init__(self, model): # skipcq: PYL-W0621 self.model = model @@ -73,7 +70,6 @@ def method_from_exposed(arg): try: class ClassificationInferenceRepeated(ModelComponent): - def __init__(self, model): self.model = model @@ -92,13 +88,14 @@ def classify(self, img): img = img.permute(0, 3, 2, 1) out = self.model(img) return ([out.argmax(), out.argmax()], torch.Tensor([21])) + + except TypeError: ClassificationInferenceRepeated = None try: class ClassificationInferenceModelSequence(ModelComponent): - def __init__(self, model): self.model1 = model[0] self.model2 = model[1] @@ -117,13 +114,14 @@ def classify(self, img): out2 = self.model2(img) assert out.argmax() == out2.argmax() return out.argmax() + + except TypeError: ClassificationInferenceRepeated = None try: class ClassificationInferenceModelMapping(ModelComponent): - def __init__(self, model): self.model1 = model["model_one"] self.model2 = model["model_two"] @@ -142,13 +140,14 @@ def classify(self, img): out2 = self.model2(img) assert out.argmax() == out2.argmax() return out.argmax() + + except TypeError: ClassificationInferenceModelMapping = None try: class ClassificationInferenceComposable(ModelComponent): - def __init__(self, model): self.model = model @@ -171,13 +170,14 @@ def classify(self, img, tag): out = self.model(img_new) return out.argmax(), img + + except TypeError: ClassificationInferenceComposable = None try: class SeatClassifier(ModelComponent): - def __init__(self, model, config): self.sport = config["sport"] @@ -197,5 +197,7 @@ def predict(self, section, isle, row, stadium): seat_num = section.item() * isle.item() * row.item() * stadium * len(self.sport) stadium_idx = torch.tensor(1000) return torch.Tensor([seat_num]), stadium_idx + + except TypeError: SeatClassifier = None diff --git a/tests/core/serve/test_compat/test_cached_property.py b/tests/core/serve/test_compat/test_cached_property.py index c6c909bdf8..b708fa8189 100644 --- a/tests/core/serve/test_compat/test_cached_property.py +++ b/tests/core/serve/test_compat/test_cached_property.py @@ -79,7 +79,6 @@ def cost(self): # noinspection PyStatementEffect @pytest.mark.skipif(sys.version_info >= (3, 8), reason="Python 3.8+ uses standard library implementation.") class TestCachedProperty: - @staticmethod def test_cached(): item = CachedCostItem() @@ -125,7 +124,6 @@ def test_object_with_slots(): @staticmethod def test_immutable_dict(): - class MyMeta(type): """Test metaclass.""" @@ -214,7 +212,6 @@ def test_doc(): @pytest.mark.skipif(sys.version_info < (3, 8), reason="Validate, that python 3.8 uses standard implementation") class TestPy38Plus: - @staticmethod def test_is(): import functools diff --git a/tests/core/serve/test_components.py b/tests/core/serve/test_components.py index a32773726f..f31f89c84a 100644 --- a/tests/core/serve/test_components.py +++ b/tests/core/serve/test_components.py @@ -21,12 +21,14 @@ def test_model_compute_dependencies(lightning_squeezenet1_1_obj): comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) comp1.inputs.tag << comp2.outputs.predicted_tag - res = [{ - "source_component": "callnum_2", - "source_key": "predicted_tag", - "target_component": "callnum_1", - "target_key": "tag", - }] + res = [ + { + "source_component": "callnum_2", + "source_key": "predicted_tag", + "target_component": "callnum_1", + "target_key": "tag", + } + ] assert list(map(lambda x: x._asdict(), comp1._flashserve_meta_.connections)) == res assert list(comp2._flashserve_meta_.connections) == [] @@ -38,12 +40,14 @@ def test_inverse_model_compute_component_dependencies(lightning_squeezenet1_1_ob comp2.outputs.predicted_tag >> comp1.inputs.tag - res = [{ - "source_component": "callnum_2", - "source_key": "predicted_tag", - "target_component": "callnum_1", - "target_key": "tag", - }] + res = [ + { + "source_component": "callnum_2", + "source_key": "predicted_tag", + "target_component": "callnum_1", + "target_key": "tag", + } + ] assert list(map(lambda x: x._asdict(), comp2._flashserve_meta_.connections)) == res assert list(comp1._flashserve_meta_.connections) == [] @@ -74,7 +78,6 @@ def test_two_component_invalid_dependencies_fail(lightning_squeezenet1_1_obj): comp2.outputs.predicted_tag >> comp1.outputs.predicted_tag class Foo: - def __init__(self): pass @@ -128,7 +131,6 @@ def test_invalid_expose_inputs(): with pytest.raises(SyntaxError, match="must be valid python attribute"): class ComposeClassInvalidExposeNameKeyword(ModelComponent): - def __init__(self, model): pass @@ -142,7 +144,6 @@ def predict(param): with pytest.raises(AttributeError, match="object has no attribute"): class ComposeClassInvalidExposeNameType(ModelComponent): - def __init__(self, model): pass @@ -156,7 +157,6 @@ def predict(param): with pytest.raises(TypeError, match="`expose` values must be"): class ComposeClassInvalidExposeInputsType(ModelComponent): - def __init__(self, model): pass @@ -170,7 +170,6 @@ def predict(param): with pytest.raises(ValueError, match="cannot set dict of length < 1"): class ComposeClassEmptyExposeInputsType(ModelComponent): - def __init__(self, model): pass @@ -206,7 +205,6 @@ def test_invalid_name(lightning_squeezenet1_1_obj): with pytest.raises(SyntaxError): class FailedExposedOutputsKeyworkName(ModelComponent): - def __init__(self, model): self.model = model @@ -222,7 +220,6 @@ def test_invalid_config_args(lightning_squeezenet1_1_obj): from flash.core.serve.types import Number class SomeComponent(ModelComponent): - def __init__(self, model, config=None): self.model = model self.config = config @@ -250,7 +247,6 @@ def test_invalid_model_args(lightning_squeezenet1_1_obj): from flash.core.serve.types import Number class SomeComponent(ModelComponent): - def __init__(self, model): self.model = model diff --git a/tests/core/serve/test_composition.py b/tests/core/serve/test_composition.py index 5679859ee2..c354e64f2f 100644 --- a/tests/core/serve/test_composition.py +++ b/tests/core/serve/test_composition.py @@ -23,10 +23,7 @@ def test_composit_endpoint_data(lightning_squeezenet1_1_obj): actual_endpoints = {k: asdict(v) for k, v in composit.endpoints.items()} assert actual_endpoints == { "classify_ENDPOINT": { - "inputs": { - "img": "callnum_1.inputs.img", - "tag": "callnum_1.inputs.tag" - }, + "inputs": {"img": "callnum_1.inputs.img", "tag": "callnum_1.inputs.tag"}, "outputs": { "cropped_img": "callnum_1.outputs.cropped_img", "predicted_tag": "callnum_1.outputs.predicted_tag", @@ -50,10 +47,7 @@ def test_composit_endpoint_data(lightning_squeezenet1_1_obj): actual_endpoints = {k: asdict(v) for k, v in composit.endpoints.items()} assert actual_endpoints == { "predict_ep": { - "inputs": { - "label_1": "callnum_1.inputs.img", - "tag_1": "callnum_1.inputs.tag" - }, + "inputs": {"label_1": "callnum_1.inputs.img", "tag_1": "callnum_1.inputs.tag"}, "outputs": { "cropped": "callnum_1.outputs.cropped_img", "prediction": "callnum_1.outputs.predicted_tag", @@ -381,21 +375,13 @@ def test_start_server_from_composition(tmp_path, squeezenet_servable, session_gl data = { "session": "session_uuid", "payload": { - "img_1": { - "data": cat_imgstr - }, - "img_2": { - "data": fish_imgstr - }, - "tag_1": { - "label": "stingray" - }, + "img_1": {"data": cat_imgstr}, + "img_2": {"data": fish_imgstr}, + "tag_1": {"label": "stingray"}, }, } expected_response = { - "result": { - "prediction": "goldfish, Carassius auratus" - }, + "result": {"prediction": "goldfish, Carassius auratus"}, "session": "session_uuid", } diff --git a/tests/core/serve/test_dag/test_optimization.py b/tests/core/serve/test_dag/test_optimization.py index 238adcfa3c..673dce8106 100644 --- a/tests/core/serve/test_dag/test_optimization.py +++ b/tests/core/serve/test_dag/test_optimization.py @@ -36,7 +36,7 @@ def test_cull(): def fuse2(*args, **kwargs): - """Run both ``fuse`` and ``fuse_linear`` and compare results""" + """Run both ``fuse`` and ``fuse_linear`` and compare results.""" rv1 = fuse_linear(*args, **kwargs) if kwargs.get("rename_keys") is not False: return rv1 @@ -60,12 +60,14 @@ def test_fuse(): "b": 2, } assert fuse(d, rename_keys=False) == with_deps({"w": (inc, (inc, (inc, (add, "a", "b")))), "a": 1, "b": 2}) - assert fuse(d, rename_keys=True) == with_deps({ - "z-y-x-w": (inc, (inc, (inc, (add, "a", "b")))), - "a": 1, - "b": 2, - "w": "z-y-x-w", - }) + assert fuse(d, rename_keys=True) == with_deps( + { + "z-y-x-w": (inc, (inc, (inc, (add, "a", "b")))), + "a": 1, + "b": 2, + "w": "z-y-x-w", + } + ) d = { "NEW": (inc, "y"), @@ -76,22 +78,26 @@ def test_fuse(): "a": 1, "b": 2, } - assert fuse(d, rename_keys=False) == with_deps({ - "NEW": (inc, "y"), - "w": (inc, (inc, "y")), - "y": (inc, (add, "a", "b")), - "a": 1, - "b": 2, - }) - assert fuse(d, rename_keys=True) == with_deps({ - "NEW": (inc, "z-y"), - "x-w": (inc, (inc, "z-y")), - "z-y": (inc, (add, "a", "b")), - "a": 1, - "b": 2, - "w": "x-w", - "y": "z-y", - }) + assert fuse(d, rename_keys=False) == with_deps( + { + "NEW": (inc, "y"), + "w": (inc, (inc, "y")), + "y": (inc, (add, "a", "b")), + "a": 1, + "b": 2, + } + ) + assert fuse(d, rename_keys=True) == with_deps( + { + "NEW": (inc, "z-y"), + "x-w": (inc, (inc, "z-y")), + "z-y": (inc, (add, "a", "b")), + "a": 1, + "b": 2, + "w": "x-w", + "y": "z-y", + } + ) d = { "v": (inc, "y"), @@ -105,24 +111,28 @@ def test_fuse(): "c": 1, "d": 2, } - assert fuse(d, rename_keys=False) == with_deps({ - "u": (inc, (inc, (inc, "y"))), - "v": (inc, "y"), - "y": (inc, (add, "a", "b")), - "a": (inc, 1), - "b": (inc, 2), - }) - assert fuse(d, rename_keys=True) == with_deps({ - "x-w-u": (inc, (inc, (inc, "z-y"))), - "v": (inc, "z-y"), - "z-y": (inc, (add, "c-a", "d-b")), - "c-a": (inc, 1), - "d-b": (inc, 2), - "a": "c-a", - "b": "d-b", - "u": "x-w-u", - "y": "z-y", - }) + assert fuse(d, rename_keys=False) == with_deps( + { + "u": (inc, (inc, (inc, "y"))), + "v": (inc, "y"), + "y": (inc, (add, "a", "b")), + "a": (inc, 1), + "b": (inc, 2), + } + ) + assert fuse(d, rename_keys=True) == with_deps( + { + "x-w-u": (inc, (inc, (inc, "z-y"))), + "v": (inc, "z-y"), + "z-y": (inc, (add, "c-a", "d-b")), + "c-a": (inc, 1), + "d-b": (inc, 2), + "a": "c-a", + "b": "d-b", + "u": "x-w-u", + "y": "z-y", + } + ) d = { "a": (inc, "x"), @@ -132,20 +142,19 @@ def test_fuse(): "x": (inc, "y"), "y": 0, } - assert fuse(d, rename_keys=False) == with_deps({ - "a": (inc, "x"), - "b": (inc, "x"), - "d": (inc, (inc, "x")), - "x": (inc, 0) - }) - assert fuse(d, rename_keys=True) == with_deps({ - "a": (inc, "y-x"), - "b": (inc, "y-x"), - "c-d": (inc, (inc, "y-x")), - "y-x": (inc, 0), - "d": "c-d", - "x": "y-x", - }) + assert fuse(d, rename_keys=False) == with_deps( + {"a": (inc, "x"), "b": (inc, "x"), "d": (inc, (inc, "x")), "x": (inc, 0)} + ) + assert fuse(d, rename_keys=True) == with_deps( + { + "a": (inc, "y-x"), + "b": (inc, "y-x"), + "c-d": (inc, (inc, "y-x")), + "y-x": (inc, 0), + "d": "c-d", + "x": "y-x", + } + ) d = {"a": 1, "b": (inc, "a"), "c": (add, "b", "b")} assert fuse(d, rename_keys=False) == with_deps({"b": (inc, 1), "c": (add, "b", "b")}) @@ -168,21 +177,19 @@ def test_fuse_keys(): "b": 2, } keys = ["x", "z"] - assert fuse(d, keys, rename_keys=False) == with_deps({ - "w": (inc, "x"), - "x": (inc, (inc, "z")), - "z": (add, "a", "b"), - "a": 1, - "b": 2 - }) - assert fuse(d, keys, rename_keys=True) == with_deps({ - "w": (inc, "y-x"), - "y-x": (inc, (inc, "z")), - "z": (add, "a", "b"), - "a": 1, - "b": 2, - "x": "y-x", - }) + assert fuse(d, keys, rename_keys=False) == with_deps( + {"w": (inc, "x"), "x": (inc, (inc, "z")), "z": (add, "a", "b"), "a": 1, "b": 2} + ) + assert fuse(d, keys, rename_keys=True) == with_deps( + { + "w": (inc, "y-x"), + "y-x": (inc, (inc, "z")), + "z": (add, "a", "b"), + "a": 1, + "b": 2, + "x": "y-x", + } + ) def test_inline(): @@ -238,9 +245,7 @@ def test_inline_ignores_curries_and_partials(): def test_inline_functions_non_hashable(): - class NonHashableCallable: - def __call__(self, a): return a + 1 @@ -277,7 +282,6 @@ def test_inline_functions_protects_output_keys(): def test_functions_of(): - def a(x): return x @@ -290,7 +294,7 @@ def b(x): assert functions_of((a, [[[(b, 1)]]])) == {a, b} assert functions_of(1) == set() assert functions_of(a) == set() - assert functions_of((a, )) == {a} + assert functions_of((a,)) == {a} def test_inline_cull_dependencies(): @@ -301,7 +305,6 @@ def test_inline_cull_dependencies(): def test_fuse_reductions_single_input(): - def f(*args): return args @@ -309,11 +312,9 @@ def f(*args): assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d) assert fuse(d, ave_width=2, rename_keys=False) == with_deps({"a": 1, "c": (f, (f, "a"), (f, "a", "a"))}) - assert fuse(d, ave_width=2, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-c": (f, (f, "a"), (f, "a", "a")), - "c": "b1-b2-c" - }) + assert fuse(d, ave_width=2, rename_keys=True) == with_deps( + {"a": 1, "b1-b2-c": (f, (f, "a"), (f, "a", "a")), "c": "b1-b2-c"} + ) d = { "a": 1, @@ -324,25 +325,24 @@ def f(*args): } assert fuse(d, ave_width=2.9, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=2.9, rename_keys=True) == with_deps(d) - assert fuse(d, ave_width=3, rename_keys=False) == with_deps({ - "a": 1, - "c": (f, (f, "a"), (f, "a", "a"), (f, "a", "a", "a")) - }) - assert fuse(d, ave_width=3, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-b3-c": (f, (f, "a"), (f, "a", "a"), (f, "a", "a", "a")), - "c": "b1-b2-b3-c", - }) + assert fuse(d, ave_width=3, rename_keys=False) == with_deps( + {"a": 1, "c": (f, (f, "a"), (f, "a", "a"), (f, "a", "a", "a"))} + ) + assert fuse(d, ave_width=3, rename_keys=True) == with_deps( + { + "a": 1, + "b1-b2-b3-c": (f, (f, "a"), (f, "a", "a"), (f, "a", "a", "a")), + "c": "b1-b2-b3-c", + } + ) d = {"a": 1, "b1": (f, "a"), "b2": (f, "a"), "c": (f, "a", "b1", "b2")} assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d) assert fuse(d, ave_width=2, rename_keys=False) == with_deps({"a": 1, "c": (f, "a", (f, "a"), (f, "a"))}) - assert fuse(d, ave_width=2, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-c": (f, "a", (f, "a"), (f, "a")), - "c": "b1-b2-c" - }) + assert fuse(d, ave_width=2, rename_keys=True) == with_deps( + {"a": 1, "b1-b2-c": (f, "a", (f, "a"), (f, "a")), "c": "b1-b2-c"} + ) d = { "a": 1, @@ -355,18 +355,18 @@ def f(*args): } assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d) - assert fuse(d, ave_width=2, rename_keys=False) == with_deps({ - "a": 1, - "c": (f, (f, "a"), (f, "a")), - "e": (f, (f, "c"), (f, "c")) - }) - assert fuse(d, ave_width=2, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-c": (f, (f, "a"), (f, "a")), - "d1-d2-e": (f, (f, "c"), (f, "c")), - "c": "b1-b2-c", - "e": "d1-d2-e", - }) + assert fuse(d, ave_width=2, rename_keys=False) == with_deps( + {"a": 1, "c": (f, (f, "a"), (f, "a")), "e": (f, (f, "c"), (f, "c"))} + ) + assert fuse(d, ave_width=2, rename_keys=True) == with_deps( + { + "a": 1, + "b1-b2-c": (f, (f, "a"), (f, "a")), + "d1-d2-e": (f, (f, "c"), (f, "c")), + "c": "b1-b2-c", + "e": "d1-d2-e", + } + ) d = { "a": 1, @@ -380,37 +380,42 @@ def f(*args): } assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d) - expected = with_deps({ - "a": 1, - "c1": (f, (f, "a"), (f, "a")), - "c2": (f, (f, "a"), (f, "a")), - "d": (f, "c1", "c2"), - }) + expected = with_deps( + { + "a": 1, + "c1": (f, (f, "a"), (f, "a")), + "c2": (f, (f, "a"), (f, "a")), + "d": (f, "c1", "c2"), + } + ) assert fuse(d, ave_width=2, rename_keys=False) == expected assert fuse(d, ave_width=2.9, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b1-b2-c1": (f, (f, "a"), (f, "a")), - "b3-b4-c2": (f, (f, "a"), (f, "a")), - "d": (f, "c1", "c2"), - "c1": "b1-b2-c1", - "c2": "b3-b4-c2", - }) + expected = with_deps( + { + "a": 1, + "b1-b2-c1": (f, (f, "a"), (f, "a")), + "b3-b4-c2": (f, (f, "a"), (f, "a")), + "d": (f, "c1", "c2"), + "c1": "b1-b2-c1", + "c2": "b3-b4-c2", + } + ) assert fuse(d, ave_width=2, rename_keys=True) == expected assert fuse(d, ave_width=2.9, rename_keys=True) == expected - assert fuse(d, ave_width=3, rename_keys=False) == with_deps({ - "a": 1, - "d": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))) - }) - assert fuse(d, ave_width=3, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-b3-b4-c1-c2-d": ( - f, - (f, (f, "a"), (f, "a")), - (f, (f, "a"), (f, "a")), - ), - "d": "b1-b2-b3-b4-c1-c2-d", - }) + assert fuse(d, ave_width=3, rename_keys=False) == with_deps( + {"a": 1, "d": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a")))} + ) + assert fuse(d, ave_width=3, rename_keys=True) == with_deps( + { + "a": 1, + "b1-b2-b3-b4-c1-c2-d": ( + f, + (f, (f, "a"), (f, "a")), + (f, (f, "a"), (f, "a")), + ), + "d": "b1-b2-b3-b4-c1-c2-d", + } + ) d = { "a": 1, @@ -432,77 +437,89 @@ def f(*args): } assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d) - expected = with_deps({ - "a": 1, - "c1": (f, (f, "a"), (f, "a")), - "c2": (f, (f, "a"), (f, "a")), - "c3": (f, (f, "a"), (f, "a")), - "c4": (f, (f, "a"), (f, "a")), - "d1": (f, "c1", "c2"), - "d2": (f, "c3", "c4"), - "e": (f, "d1", "d2"), - }) + expected = with_deps( + { + "a": 1, + "c1": (f, (f, "a"), (f, "a")), + "c2": (f, (f, "a"), (f, "a")), + "c3": (f, (f, "a"), (f, "a")), + "c4": (f, (f, "a"), (f, "a")), + "d1": (f, "c1", "c2"), + "d2": (f, "c3", "c4"), + "e": (f, "d1", "d2"), + } + ) assert fuse(d, ave_width=2, rename_keys=False) == expected assert fuse(d, ave_width=2.9, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b1-b2-c1": (f, (f, "a"), (f, "a")), - "b3-b4-c2": (f, (f, "a"), (f, "a")), - "b5-b6-c3": (f, (f, "a"), (f, "a")), - "b7-b8-c4": (f, (f, "a"), (f, "a")), - "d1": (f, "c1", "c2"), - "d2": (f, "c3", "c4"), - "e": (f, "d1", "d2"), - "c1": "b1-b2-c1", - "c2": "b3-b4-c2", - "c3": "b5-b6-c3", - "c4": "b7-b8-c4", - }) + expected = with_deps( + { + "a": 1, + "b1-b2-c1": (f, (f, "a"), (f, "a")), + "b3-b4-c2": (f, (f, "a"), (f, "a")), + "b5-b6-c3": (f, (f, "a"), (f, "a")), + "b7-b8-c4": (f, (f, "a"), (f, "a")), + "d1": (f, "c1", "c2"), + "d2": (f, "c3", "c4"), + "e": (f, "d1", "d2"), + "c1": "b1-b2-c1", + "c2": "b3-b4-c2", + "c3": "b5-b6-c3", + "c4": "b7-b8-c4", + } + ) assert fuse(d, ave_width=2, rename_keys=True) == expected assert fuse(d, ave_width=2.9, rename_keys=True) == expected - expected = with_deps({ - "a": 1, - "d1": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - "d2": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - "e": (f, "d1", "d2"), - }) + expected = with_deps( + { + "a": 1, + "d1": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + "d2": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + "e": (f, "d1", "d2"), + } + ) assert fuse(d, ave_width=3, rename_keys=False) == expected assert fuse(d, ave_width=4.6, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b1-b2-b3-b4-c1-c2-d1": ( - f, - (f, (f, "a"), (f, "a")), - (f, (f, "a"), (f, "a")), - ), - "b5-b6-b7-b8-c3-c4-d2": ( - f, - (f, (f, "a"), (f, "a")), - (f, (f, "a"), (f, "a")), - ), - "e": (f, "d1", "d2"), - "d1": "b1-b2-b3-b4-c1-c2-d1", - "d2": "b5-b6-b7-b8-c3-c4-d2", - }) + expected = with_deps( + { + "a": 1, + "b1-b2-b3-b4-c1-c2-d1": ( + f, + (f, (f, "a"), (f, "a")), + (f, (f, "a"), (f, "a")), + ), + "b5-b6-b7-b8-c3-c4-d2": ( + f, + (f, (f, "a"), (f, "a")), + (f, (f, "a"), (f, "a")), + ), + "e": (f, "d1", "d2"), + "d1": "b1-b2-b3-b4-c1-c2-d1", + "d2": "b5-b6-b7-b8-c3-c4-d2", + } + ) assert fuse(d, ave_width=3, rename_keys=True) == expected assert fuse(d, ave_width=4.6, rename_keys=True) == expected - assert fuse(d, ave_width=4.7, rename_keys=False) == with_deps({ - "a": 1, - "e": ( - f, - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - ), - }) - assert fuse(d, ave_width=4.7, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e": ( - f, - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - ), - "e": "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e", - }) + assert fuse(d, ave_width=4.7, rename_keys=False) == with_deps( + { + "a": 1, + "e": ( + f, + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + ), + } + ) + assert fuse(d, ave_width=4.7, rename_keys=True) == with_deps( + { + "a": 1, + "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e": ( + f, + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + ), + "e": "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e", + } + ) d = { "a": 1, @@ -540,165 +557,181 @@ def f(*args): } assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d) - expected = with_deps({ - "a": 1, - "c1": (f, (f, "a"), (f, "a")), - "c2": (f, (f, "a"), (f, "a")), - "c3": (f, (f, "a"), (f, "a")), - "c4": (f, (f, "a"), (f, "a")), - "c5": (f, (f, "a"), (f, "a")), - "c6": (f, (f, "a"), (f, "a")), - "c7": (f, (f, "a"), (f, "a")), - "c8": (f, (f, "a"), (f, "a")), - "d1": (f, "c1", "c2"), - "d2": (f, "c3", "c4"), - "d3": (f, "c5", "c6"), - "d4": (f, "c7", "c8"), - "e1": (f, "d1", "d2"), - "e2": (f, "d3", "d4"), - "f": (f, "e1", "e2"), - }) + expected = with_deps( + { + "a": 1, + "c1": (f, (f, "a"), (f, "a")), + "c2": (f, (f, "a"), (f, "a")), + "c3": (f, (f, "a"), (f, "a")), + "c4": (f, (f, "a"), (f, "a")), + "c5": (f, (f, "a"), (f, "a")), + "c6": (f, (f, "a"), (f, "a")), + "c7": (f, (f, "a"), (f, "a")), + "c8": (f, (f, "a"), (f, "a")), + "d1": (f, "c1", "c2"), + "d2": (f, "c3", "c4"), + "d3": (f, "c5", "c6"), + "d4": (f, "c7", "c8"), + "e1": (f, "d1", "d2"), + "e2": (f, "d3", "d4"), + "f": (f, "e1", "e2"), + } + ) assert fuse(d, ave_width=2, rename_keys=False) == expected assert fuse(d, ave_width=2.9, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b1-b2-c1": (f, (f, "a"), (f, "a")), - "b3-b4-c2": (f, (f, "a"), (f, "a")), - "b5-b6-c3": (f, (f, "a"), (f, "a")), - "b7-b8-c4": (f, (f, "a"), (f, "a")), - "b10-b9-c5": (f, (f, "a"), (f, "a")), - "b11-b12-c6": (f, (f, "a"), (f, "a")), - "b13-b14-c7": (f, (f, "a"), (f, "a")), - "b15-b16-c8": (f, (f, "a"), (f, "a")), - "d1": (f, "c1", "c2"), - "d2": (f, "c3", "c4"), - "d3": (f, "c5", "c6"), - "d4": (f, "c7", "c8"), - "e1": (f, "d1", "d2"), - "e2": (f, "d3", "d4"), - "f": (f, "e1", "e2"), - "c1": "b1-b2-c1", - "c2": "b3-b4-c2", - "c3": "b5-b6-c3", - "c4": "b7-b8-c4", - "c5": "b10-b9-c5", - "c6": "b11-b12-c6", - "c7": "b13-b14-c7", - "c8": "b15-b16-c8", - }) + expected = with_deps( + { + "a": 1, + "b1-b2-c1": (f, (f, "a"), (f, "a")), + "b3-b4-c2": (f, (f, "a"), (f, "a")), + "b5-b6-c3": (f, (f, "a"), (f, "a")), + "b7-b8-c4": (f, (f, "a"), (f, "a")), + "b10-b9-c5": (f, (f, "a"), (f, "a")), + "b11-b12-c6": (f, (f, "a"), (f, "a")), + "b13-b14-c7": (f, (f, "a"), (f, "a")), + "b15-b16-c8": (f, (f, "a"), (f, "a")), + "d1": (f, "c1", "c2"), + "d2": (f, "c3", "c4"), + "d3": (f, "c5", "c6"), + "d4": (f, "c7", "c8"), + "e1": (f, "d1", "d2"), + "e2": (f, "d3", "d4"), + "f": (f, "e1", "e2"), + "c1": "b1-b2-c1", + "c2": "b3-b4-c2", + "c3": "b5-b6-c3", + "c4": "b7-b8-c4", + "c5": "b10-b9-c5", + "c6": "b11-b12-c6", + "c7": "b13-b14-c7", + "c8": "b15-b16-c8", + } + ) assert fuse(d, ave_width=2, rename_keys=True) == expected assert fuse(d, ave_width=2.9, rename_keys=True) == expected - expected = with_deps({ - "a": 1, - "d1": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - "d2": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - "d3": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - "d4": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - "e1": (f, "d1", "d2"), - "e2": (f, "d3", "d4"), - "f": (f, "e1", "e2"), - }) + expected = with_deps( + { + "a": 1, + "d1": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + "d2": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + "d3": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + "d4": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + "e1": (f, "d1", "d2"), + "e2": (f, "d3", "d4"), + "f": (f, "e1", "e2"), + } + ) assert fuse(d, ave_width=3, rename_keys=False) == expected assert fuse(d, ave_width=4.6, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b1-b2-b3-b4-c1-c2-d1": ( - f, - (f, (f, "a"), (f, "a")), - (f, (f, "a"), (f, "a")), - ), - "b5-b6-b7-b8-c3-c4-d2": ( - f, - (f, (f, "a"), (f, "a")), - (f, (f, "a"), (f, "a")), - ), - "b10-b11-b12-b9-c5-c6-d3": ( - f, - (f, (f, "a"), (f, "a")), - (f, (f, "a"), (f, "a")), - ), - "b13-b14-b15-b16-c7-c8-d4": ( - f, - (f, (f, "a"), (f, "a")), - (f, (f, "a"), (f, "a")), - ), - "e1": (f, "d1", "d2"), - "e2": (f, "d3", "d4"), - "f": (f, "e1", "e2"), - "d1": "b1-b2-b3-b4-c1-c2-d1", - "d2": "b5-b6-b7-b8-c3-c4-d2", - "d3": "b10-b11-b12-b9-c5-c6-d3", - "d4": "b13-b14-b15-b16-c7-c8-d4", - }) + expected = with_deps( + { + "a": 1, + "b1-b2-b3-b4-c1-c2-d1": ( + f, + (f, (f, "a"), (f, "a")), + (f, (f, "a"), (f, "a")), + ), + "b5-b6-b7-b8-c3-c4-d2": ( + f, + (f, (f, "a"), (f, "a")), + (f, (f, "a"), (f, "a")), + ), + "b10-b11-b12-b9-c5-c6-d3": ( + f, + (f, (f, "a"), (f, "a")), + (f, (f, "a"), (f, "a")), + ), + "b13-b14-b15-b16-c7-c8-d4": ( + f, + (f, (f, "a"), (f, "a")), + (f, (f, "a"), (f, "a")), + ), + "e1": (f, "d1", "d2"), + "e2": (f, "d3", "d4"), + "f": (f, "e1", "e2"), + "d1": "b1-b2-b3-b4-c1-c2-d1", + "d2": "b5-b6-b7-b8-c3-c4-d2", + "d3": "b10-b11-b12-b9-c5-c6-d3", + "d4": "b13-b14-b15-b16-c7-c8-d4", + } + ) assert fuse(d, ave_width=3, rename_keys=True) == expected assert fuse(d, ave_width=4.6, rename_keys=True) == expected - expected = with_deps({ - "a": 1, - "e1": ( - f, - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - ), - "e2": ( - f, - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - ), - "f": (f, "e1", "e2"), - }) - assert fuse(d, ave_width=4.7, rename_keys=False) == expected - assert fuse(d, ave_width=7.4, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e1": ( - f, - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - ), - "b10-b11-b12-b13-b14-b15-b16-b9-c5-c6-c7-c8-d3-d4-e2": ( - f, - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), - ), - "f": (f, "e1", "e2"), - "e1": "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e1", - "e2": "b10-b11-b12-b13-b14-b15-b16-b9-c5-c6-c7-c8-d3-d4-e2", - }) - assert fuse(d, ave_width=4.7, rename_keys=True) == expected - assert fuse(d, ave_width=7.4, rename_keys=True) == expected - assert fuse(d, ave_width=7.5, rename_keys=False) == with_deps({ - "a": 1, - "f": ( - f, - ( + expected = with_deps( + { + "a": 1, + "e1": ( f, (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), ), - ( + "e2": ( f, (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), ), - ), - }) - assert fuse(d, ave_width=7.5, rename_keys=True) == with_deps({ - "a": 1, - "b1-b10-b11-b12-b13-b14-b15-b16-b2-b3-b4-b5-b6-b7-b8-b9-c1-c2-c3-c4-c5-c6-c7-c8-d1-d2-d3-d4-e1-e2-f": ( - f, - ( + "f": (f, "e1", "e2"), + } + ) + assert fuse(d, ave_width=4.7, rename_keys=False) == expected + assert fuse(d, ave_width=7.4, rename_keys=False) == expected + expected = with_deps( + { + "a": 1, + "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e1": ( f, (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), ), - ( + "b10-b11-b12-b13-b14-b15-b16-b9-c5-c6-c7-c8-d3-d4-e2": ( f, (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), ), - ), - "f": "b1-b10-b11-b12-b13-b14-b15-b16-b2-b3-b4-b5-b6-b7-b8-b9-c1-c2-c3-c4-c5-c6-c7-c8-d1-d2-d3-d4-e1-e2-f", - }) + "f": (f, "e1", "e2"), + "e1": "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e1", + "e2": "b10-b11-b12-b13-b14-b15-b16-b9-c5-c6-c7-c8-d3-d4-e2", + } + ) + assert fuse(d, ave_width=4.7, rename_keys=True) == expected + assert fuse(d, ave_width=7.4, rename_keys=True) == expected + assert fuse(d, ave_width=7.5, rename_keys=False) == with_deps( + { + "a": 1, + "f": ( + f, + ( + f, + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + ), + ( + f, + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + ), + ), + } + ) + assert fuse(d, ave_width=7.5, rename_keys=True) == with_deps( + { + "a": 1, + "b1-b10-b11-b12-b13-b14-b15-b16-b2-b3-b4-b5-b6-b7-b8-b9-c1-c2-c3-c4-c5-c6-c7-c8-d1-d2-d3-d4-e1-e2-f": ( + f, + ( + f, + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + ), + ( + f, + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))), + ), + ), + "f": "b1-b10-b11-b12-b13-b14-b15-b16-b2-b3-b4-b5-b6-b7-b8-b9-c1-c2-c3-c4-c5-c6-c7-c8-d1-d2-d3-d4-e1-e2-f", + } + ) d = {"a": 1, "b": (f, "a")} assert fuse(d, ave_width=1, rename_keys=False) == with_deps({"b": (f, 1)}) @@ -710,11 +743,9 @@ def f(*args): d = {"a": 1, "b": (f, "a"), "c": (f, "a", "b"), "d": (f, "a", "c")} assert fuse(d, ave_width=1, rename_keys=False) == with_deps({"a": 1, "d": (f, "a", (f, "a", (f, "a")))}) - assert fuse(d, ave_width=1, rename_keys=True) == with_deps({ - "a": 1, - "b-c-d": (f, "a", (f, "a", (f, "a"))), - "d": "b-c-d" - }) + assert fuse(d, ave_width=1, rename_keys=True) == with_deps( + {"a": 1, "b-c-d": (f, "a", (f, "a", (f, "a"))), "d": "b-c-d"} + ) d = { "a": 1, @@ -728,21 +759,25 @@ def f(*args): expected = with_deps({"a": 1, "b2": (f, "a"), "e1": (f, (f, (f, (f, "a")))), "f": (f, "e1", "b2")}) assert fuse(d, ave_width=1, rename_keys=False) == expected assert fuse(d, ave_width=1.9, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b2": (f, "a"), - "b1-c1-d1-e1": (f, (f, (f, (f, "a")))), - "f": (f, "e1", "b2"), - "e1": "b1-c1-d1-e1", - }) + expected = with_deps( + { + "a": 1, + "b2": (f, "a"), + "b1-c1-d1-e1": (f, (f, (f, (f, "a")))), + "f": (f, "e1", "b2"), + "e1": "b1-c1-d1-e1", + } + ) assert fuse(d, ave_width=1, rename_keys=True) == expected assert fuse(d, ave_width=1.9, rename_keys=True) == expected assert fuse(d, ave_width=2, rename_keys=False) == with_deps({"a": 1, "f": (f, (f, (f, (f, (f, "a")))), (f, "a"))}) - assert fuse(d, ave_width=2, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-c1-d1-e1-f": (f, (f, (f, (f, (f, "a")))), (f, "a")), - "f": "b1-b2-c1-d1-e1-f", - }) + assert fuse(d, ave_width=2, rename_keys=True) == with_deps( + { + "a": 1, + "b1-b2-c1-d1-e1-f": (f, (f, (f, (f, (f, "a")))), (f, "a")), + "f": "b1-b2-c1-d1-e1-f", + } + ) d = { "a": 1, @@ -753,37 +788,42 @@ def f(*args): "e1": (f, "a", "d1"), "f": (f, "a", "e1", "b2"), } - expected = with_deps({ - "a": 1, - "b2": (f, "a"), - "e1": (f, "a", (f, "a", (f, "a", (f, "a")))), - "f": (f, "a", "e1", "b2"), - }) + expected = with_deps( + { + "a": 1, + "b2": (f, "a"), + "e1": (f, "a", (f, "a", (f, "a", (f, "a")))), + "f": (f, "a", "e1", "b2"), + } + ) assert fuse(d, ave_width=1, rename_keys=False) == expected assert fuse(d, ave_width=1.9, rename_keys=False) == expected - expected = with_deps({ - "a": 1, - "b2": (f, "a"), - "b1-c1-d1-e1": (f, "a", (f, "a", (f, "a", (f, "a")))), - "f": (f, "a", "e1", "b2"), - "e1": "b1-c1-d1-e1", - }) + expected = with_deps( + { + "a": 1, + "b2": (f, "a"), + "b1-c1-d1-e1": (f, "a", (f, "a", (f, "a", (f, "a")))), + "f": (f, "a", "e1", "b2"), + "e1": "b1-c1-d1-e1", + } + ) assert fuse(d, ave_width=1, rename_keys=True) == expected assert fuse(d, ave_width=1.9, rename_keys=True) == expected - assert fuse(d, ave_width=2, rename_keys=False) == with_deps({ - "a": 1, - "f": (f, "a", (f, "a", (f, "a", (f, "a", (f, "a")))), (f, "a")) - }) - assert fuse(d, ave_width=2, rename_keys=True) == with_deps({ - "a": 1, - "b1-b2-c1-d1-e1-f": ( - f, - "a", - (f, "a", (f, "a", (f, "a", (f, "a")))), - (f, "a"), - ), - "f": "b1-b2-c1-d1-e1-f", - }) + assert fuse(d, ave_width=2, rename_keys=False) == with_deps( + {"a": 1, "f": (f, "a", (f, "a", (f, "a", (f, "a", (f, "a")))), (f, "a"))} + ) + assert fuse(d, ave_width=2, rename_keys=True) == with_deps( + { + "a": 1, + "b1-b2-c1-d1-e1-f": ( + f, + "a", + (f, "a", (f, "a", (f, "a", (f, "a")))), + (f, "a"), + ), + "f": "b1-b2-c1-d1-e1-f", + } + ) d = { "a": 1, @@ -800,24 +840,28 @@ def f(*args): "f": (f, "e"), "g": (f, "f"), } - assert fuse(d, ave_width=1, rename_keys=False) == with_deps({ - "a": 1, - "d1": (f, (f, (f, "a"))), - "d2": (f, (f, (f, "a"))), - "d3": (f, (f, (f, "a"))), - "g": (f, (f, (f, "d1", "d2", "d3"))), - }) - assert fuse(d, ave_width=1, rename_keys=True) == with_deps({ - "a": 1, - "b1-c1-d1": (f, (f, (f, "a"))), - "b2-c2-d2": (f, (f, (f, "a"))), - "b3-c3-d3": (f, (f, (f, "a"))), - "e-f-g": (f, (f, (f, "d1", "d2", "d3"))), - "d1": "b1-c1-d1", - "d2": "b2-c2-d2", - "d3": "b3-c3-d3", - "g": "e-f-g", - }) + assert fuse(d, ave_width=1, rename_keys=False) == with_deps( + { + "a": 1, + "d1": (f, (f, (f, "a"))), + "d2": (f, (f, (f, "a"))), + "d3": (f, (f, (f, "a"))), + "g": (f, (f, (f, "d1", "d2", "d3"))), + } + ) + assert fuse(d, ave_width=1, rename_keys=True) == with_deps( + { + "a": 1, + "b1-c1-d1": (f, (f, (f, "a"))), + "b2-c2-d2": (f, (f, (f, "a"))), + "b3-c3-d3": (f, (f, (f, "a"))), + "e-f-g": (f, (f, (f, "d1", "d2", "d3"))), + "d1": "b1-c1-d1", + "d2": "b2-c2-d2", + "d3": "b3-c3-d3", + "g": "e-f-g", + } + ) d = { "a": 1, @@ -828,23 +872,22 @@ def f(*args): "f": (f, "e"), "g": (f, "d", "f"), } - assert fuse(d, ave_width=1, rename_keys=False) == with_deps({ - "b": (f, 1), - "d": (f, "b", (f, "b")), - "g": (f, "d", (f, (f, "d"))) - }) - assert fuse(d, ave_width=1, rename_keys=True) == with_deps({ - "a-b": (f, 1), - "c-d": (f, "b", (f, "b")), - "e-f-g": (f, "d", (f, (f, "d"))), - "b": "a-b", - "d": "c-d", - "g": "e-f-g", - }) + assert fuse(d, ave_width=1, rename_keys=False) == with_deps( + {"b": (f, 1), "d": (f, "b", (f, "b")), "g": (f, "d", (f, (f, "d")))} + ) + assert fuse(d, ave_width=1, rename_keys=True) == with_deps( + { + "a-b": (f, 1), + "c-d": (f, "b", (f, "b")), + "e-f-g": (f, "d", (f, (f, "d"))), + "b": "a-b", + "d": "c-d", + "g": "e-f-g", + } + ) def test_fuse_stressed(): - def f(*args): return args @@ -917,7 +960,6 @@ def f(*args): def test_fuse_reductions_multiple_input(): - def f(*args): return args @@ -925,12 +967,9 @@ def f(*args): assert fuse(d, ave_width=2, rename_keys=False) == with_deps({"c": (f, (f, 1, 2))}) assert fuse(d, ave_width=2, rename_keys=True) == with_deps({"a1-a2-b-c": (f, (f, 1, 2)), "c": "a1-a2-b-c"}) assert fuse(d, ave_width=1, rename_keys=False) == with_deps({"a1": 1, "a2": 2, "c": (f, (f, "a1", "a2"))}) - assert fuse(d, ave_width=1, rename_keys=True) == with_deps({ - "a1": 1, - "a2": 2, - "b-c": (f, (f, "a1", "a2")), - "c": "b-c" - }) + assert fuse(d, ave_width=1, rename_keys=True) == with_deps( + {"a1": 1, "a2": 2, "b-c": (f, (f, "a1", "a2")), "c": "b-c"} + ) d = { "a1": 1, @@ -945,17 +984,17 @@ def f(*args): assert fuse(d, ave_width=2.9, rename_keys=False) == expected assert fuse(d, ave_width=1, rename_keys=True) == expected assert fuse(d, ave_width=2.9, rename_keys=True) == expected - assert fuse(d, ave_width=3, rename_keys=False) == with_deps({ - "a1": 1, - "a2": 2, - "c": (f, (f, "a1"), (f, "a1", "a2"), (f, "a2")) - }) - assert fuse(d, ave_width=3, rename_keys=True) == with_deps({ - "a1": 1, - "a2": 2, - "b1-b2-b3-c": (f, (f, "a1"), (f, "a1", "a2"), (f, "a2")), - "c": "b1-b2-b3-c", - }) + assert fuse(d, ave_width=3, rename_keys=False) == with_deps( + {"a1": 1, "a2": 2, "c": (f, (f, "a1"), (f, "a1", "a2"), (f, "a2"))} + ) + assert fuse(d, ave_width=3, rename_keys=True) == with_deps( + { + "a1": 1, + "a2": 2, + "b1-b2-b3-c": (f, (f, "a1"), (f, "a1", "a2"), (f, "a2")), + "c": "b1-b2-b3-c", + } + ) d = { "a1": 1, @@ -968,22 +1007,26 @@ def f(*args): } assert fuse(d, ave_width=1, rename_keys=False) == with_deps(d) assert fuse(d, ave_width=1, rename_keys=True) == with_deps(d) - assert fuse(d, ave_width=2, rename_keys=False) == with_deps({ - "a1": 1, - "a2": 2, - "b2": (f, "a1", "a2"), - "c1": (f, (f, "a1"), "b2"), - "c2": (f, "b2", (f, "a2")), - }) - assert fuse(d, ave_width=2, rename_keys=True) == with_deps({ - "a1": 1, - "a2": 2, - "b2": (f, "a1", "a2"), - "b1-c1": (f, (f, "a1"), "b2"), - "b3-c2": (f, "b2", (f, "a2")), - "c1": "b1-c1", - "c2": "b3-c2", - }) + assert fuse(d, ave_width=2, rename_keys=False) == with_deps( + { + "a1": 1, + "a2": 2, + "b2": (f, "a1", "a2"), + "c1": (f, (f, "a1"), "b2"), + "c2": (f, "b2", (f, "a2")), + } + ) + assert fuse(d, ave_width=2, rename_keys=True) == with_deps( + { + "a1": 1, + "a2": 2, + "b2": (f, "a1", "a2"), + "b1-c1": (f, (f, "a1"), "b2"), + "b3-c2": (f, "b2", (f, "a2")), + "c1": "b1-c1", + "c2": "b3-c2", + } + ) d = { "a1": 1, @@ -1000,19 +1043,23 @@ def f(*args): # A more aggressive heuristic could do this at `ave_width=2`. Perhaps # we can improve this. Nevertheless, this is behaving as intended. - assert fuse(d, ave_width=3, rename_keys=False) == with_deps({ - "a1": 1, - "a2": 2, - "b2": (f, "a1", "a2"), - "d": (f, (f, (f, "a1"), "b2"), (f, "b2", (f, "a2"))), - }) - assert fuse(d, ave_width=3, rename_keys=True) == with_deps({ - "a1": 1, - "a2": 2, - "b2": (f, "a1", "a2"), - "b1-b3-c1-c2-d": (f, (f, (f, "a1"), "b2"), (f, "b2", (f, "a2"))), - "d": "b1-b3-c1-c2-d", - }) + assert fuse(d, ave_width=3, rename_keys=False) == with_deps( + { + "a1": 1, + "a2": 2, + "b2": (f, "a1", "a2"), + "d": (f, (f, (f, "a1"), "b2"), (f, "b2", (f, "a2"))), + } + ) + assert fuse(d, ave_width=3, rename_keys=True) == with_deps( + { + "a1": 1, + "a2": 2, + "b2": (f, "a1", "a2"), + "b1-b3-c1-c2-d": (f, (f, (f, "a1"), "b2"), (f, "b2", (f, "a2"))), + "d": "b1-b3-c1-c2-d", + } + ) def func_with_kwargs(a, b, c=2): @@ -1028,20 +1075,13 @@ def test_SubgraphCallable(): apply, partial_by_order, ["in2"], - { - "function": func_with_kwargs, - "other": [(1, 20)], - "c": 4 - }, + {"function": func_with_kwargs, "other": [(1, 20)], "c": 4}, ), "c": ( apply, partial_by_order, ["in2", "in1"], - { - "function": func_with_kwargs, - "other": [(1, 20)] - }, + {"function": func_with_kwargs, "other": [(1, 20)]}, ), "d": (inc, "a"), "e": (add, "c", "d"), @@ -1105,54 +1145,60 @@ def test_fuse_subgraphs(): } res = fuse(dsk, "inc-6", fuse_subgraphs=True) - sol = with_deps({ - "inc-6": "add-inc-x-1", - "add-inc-x-1": ( - SubgraphCallable( - { - "x-1": 1, - "add-1": (add, "x-1", (inc, (inc, "x-1"))), - "inc-6": (inc, (inc, (add, "add-1", (inc, (inc, "add-1"))))), - }, - "inc-6", - (), + sol = with_deps( + { + "inc-6": "add-inc-x-1", + "add-inc-x-1": ( + SubgraphCallable( + { + "x-1": 1, + "add-1": (add, "x-1", (inc, (inc, "x-1"))), + "inc-6": (inc, (inc, (add, "add-1", (inc, (inc, "add-1"))))), + }, + "inc-6", + (), + ), ), - ), - }) + } + ) assert res == sol res = fuse(dsk, "inc-6", fuse_subgraphs=True, rename_keys=False) - sol = with_deps({ - "inc-6": ( - SubgraphCallable( - { - "x-1": 1, - "add-1": (add, "x-1", (inc, (inc, "x-1"))), - "inc-6": (inc, (inc, (add, "add-1", (inc, (inc, "add-1"))))), - }, - "inc-6", - (), - ), - ) - }) + sol = with_deps( + { + "inc-6": ( + SubgraphCallable( + { + "x-1": 1, + "add-1": (add, "x-1", (inc, (inc, "x-1"))), + "inc-6": (inc, (inc, (add, "add-1", (inc, (inc, "add-1"))))), + }, + "inc-6", + (), + ), + ) + } + ) assert res == sol res = fuse(dsk, "add-2", fuse_subgraphs=True) - sol = with_deps({ - "add-inc-x-1": ( - SubgraphCallable( - { - "x-1": 1, - "add-1": (add, "x-1", (inc, (inc, "x-1"))), - "add-2": (add, "add-1", (inc, (inc, "add-1"))), - }, - "add-2", - (), + sol = with_deps( + { + "add-inc-x-1": ( + SubgraphCallable( + { + "x-1": 1, + "add-1": (add, "x-1", (inc, (inc, "x-1"))), + "add-2": (add, "add-1", (inc, (inc, "add-1"))), + }, + "add-2", + (), + ), ), - ), - "add-2": "add-inc-x-1", - "inc-6": (inc, (inc, "add-2")), - }) + "add-2": "add-inc-x-1", + "inc-6": (inc, (inc, "add-2")), + } + ) assert res == sol res = fuse(dsk, "inc-2", fuse_subgraphs=True) @@ -1160,24 +1206,27 @@ def test_fuse_subgraphs(): sols = [] for inkeys in itertools.permutations(("x-1", "inc-2")): sols.append( - with_deps({ - "x-1": 1, - "inc-2": (inc, (inc, "x-1")), - "inc-6": "inc-add-1", - "inc-add-1": ( - SubgraphCallable( - { - "add-1": (add, "x-1", "inc-2"), - "inc-6": ( - inc, - (inc, (add, "add-1", (inc, (inc, "add-1")))), - ), - }, - "inc-6", - inkeys, - ), - ) + inkeys, - }) + with_deps( + { + "x-1": 1, + "inc-2": (inc, (inc, "x-1")), + "inc-6": "inc-add-1", + "inc-add-1": ( + SubgraphCallable( + { + "add-1": (add, "x-1", "inc-2"), + "inc-6": ( + inc, + (inc, (add, "add-1", (inc, (inc, "add-1")))), + ), + }, + "inc-6", + inkeys, + ), + ) + + inkeys, + } + ) ) assert res in sols @@ -1186,22 +1235,25 @@ def test_fuse_subgraphs(): sols = [] for inkeys in itertools.permutations(("x-1", "inc-2")): sols.append( - with_deps({ - "x-1": 1, - "inc-2": (inc, (inc, "x-1")), - "inc-add-1": ( - SubgraphCallable( - { - "add-1": (add, "x-1", "inc-2"), - "add-2": (add, "add-1", (inc, (inc, "add-1"))), - }, - "add-2", - inkeys, - ), - ) + inkeys, - "add-2": "inc-add-1", - "inc-6": (inc, (inc, "add-2")), - }) + with_deps( + { + "x-1": 1, + "inc-2": (inc, (inc, "x-1")), + "inc-add-1": ( + SubgraphCallable( + { + "add-1": (add, "x-1", "inc-2"), + "add-2": (add, "add-1", (inc, (inc, "add-1"))), + }, + "add-2", + inkeys, + ), + ) + + inkeys, + "add-2": "inc-add-1", + "inc-6": (inc, (inc, "add-2")), + } + ) ) assert res in sols @@ -1217,31 +1269,30 @@ def test_fuse_subgraphs_linear_chains_of_duplicate_deps(): } res = fuse(dsk, "add-5", fuse_subgraphs=True) - sol = with_deps({ - "add-x-1": ( - SubgraphCallable( - { - "x-1": 1, - "add-1": (add, "x-1", "x-1"), - "add-2": (add, "add-1", "add-1"), - "add-3": (add, "add-2", "add-2"), - "add-4": (add, "add-3", "add-3"), - "add-5": (add, "add-4", "add-4"), - }, - "add-5", - (), + sol = with_deps( + { + "add-x-1": ( + SubgraphCallable( + { + "x-1": 1, + "add-1": (add, "x-1", "x-1"), + "add-2": (add, "add-1", "add-1"), + "add-3": (add, "add-2", "add-2"), + "add-4": (add, "add-3", "add-3"), + "add-5": (add, "add-4", "add-4"), + }, + "add-5", + (), + ), ), - ), - "add-5": "add-x-1", - }) + "add-5": "add-x-1", + } + ) assert res == sol def test_dont_fuse_numpy_arrays(): - """ - Some types should stay in the graph bare - This helps with things like serialization - """ + """Some types should stay in the graph bare This helps with things like serialization.""" np = pytest.importorskip("numpy") dsk = {"x": np.arange(5), "y": (inc, "x")} diff --git a/tests/core/serve/test_dag/test_order.py b/tests/core/serve/test_dag/test_order.py index c332eb4860..50cfebdb67 100644 --- a/tests/core/serve/test_dag/test_order.py +++ b/tests/core/serve/test_dag/test_order.py @@ -20,14 +20,14 @@ def f(*args): def test_ordering_keeps_groups_together(abcde): a, b, c, d, e = abcde - d = dict(((a, i), (f, )) for i in range(4)) + d = {(a, i): (f,) for i in range(4)} d.update({(b, 0): (f, (a, 0), (a, 1)), (b, 1): (f, (a, 2), (a, 3))}) o = order(d) assert abs(o[(a, 0)] - o[(a, 1)]) == 1 assert abs(o[(a, 2)] - o[(a, 3)]) == 1 - d = dict(((a, i), (f, )) for i in range(4)) + d = {(a, i): (f,) for i in range(4)} d.update({(b, 0): (f, (a, 0), (a, 2)), (b, 1): (f, (a, 1), (a, 3))}) o = order(d) @@ -46,8 +46,8 @@ def test_avoid_broker_nodes(abcde): """ a, b, c, d, e = abcde dsk = { - (a, 0): (f, ), - (a, 1): (f, ), + (a, 0): (f,), + (a, 1): (f,), (b, 0): (f, (a, 0)), (b, 1): (f, (a, 1)), (b, 2): (f, (a, 1)), @@ -57,8 +57,8 @@ def test_avoid_broker_nodes(abcde): # Switch name of 0, 1 to ensure that this isn't due to string comparison dsk = { - (a, 1): (f, ), - (a, 0): (f, ), + (a, 1): (f,), + (a, 0): (f,), (b, 0): (f, (a, 1)), (b, 1): (f, (a, 0)), (b, 2): (f, (a, 0)), @@ -68,8 +68,8 @@ def test_avoid_broker_nodes(abcde): # Switch name of 0, 1 for "b"s too dsk = { - (a, 0): (f, ), - (a, 1): (f, ), + (a, 0): (f,), + (a, 1): (f,), (b, 1): (f, (a, 0)), (b, 0): (f, (a, 1)), (b, 2): (f, (a, 1)), @@ -161,10 +161,10 @@ def test_avoid_upwards_branching_complex(abcde): (a, 2): (f, (a, 3)), (a, 3): (f, (b, 1), (c, 1)), (b, 1): (f, (b, 2)), - (b, 2): (f, ), + (b, 2): (f,), (c, 1): (f, (c, 2)), (c, 2): (f, (c, 3)), - (c, 3): (f, ), + (c, 3): (f,), (d, 1): (f, (c, 1)), (d, 2): (f, (d, 1)), (d, 3): (f, (d, 1)), @@ -220,7 +220,7 @@ def test_prefer_deep(abcde): def test_stacklimit(abcde): - dsk = dict(("x%s" % (i + 1), (inc, "x%s" % i)) for i in range(10000)) + dsk = {"x%s" % (i + 1): (inc, "x%s" % i) for i in range(10000)} dependencies, dependents = get_deps(dsk) ndependencies(dependencies, dependents) @@ -261,7 +261,7 @@ def test_prefer_short_dependents(abcde): during the long computations. """ a, b, c, d, e = abcde - dsk = {c: (f, ), d: (f, c), e: (f, c), b: (f, c), a: (f, b)} + dsk = {c: (f,), d: (f, c), e: (f, c), b: (f, c), a: (f, b)} o = order(dsk) assert o[d] < o[b] @@ -280,24 +280,23 @@ def test_run_smaller_sections(abcde): Prefer to run acb first because then we can get that out of the way """ a, b, c, d, e = abcde - aa, bb, cc, dd = [x * 2 for x in [a, b, c, d]] + aa, bb, cc, dd = (x * 2 for x in [a, b, c, d]) expected = [a, c, b, e, d, cc, bb, aa, dd] log = [] def f(x): - def _(*args): log.append(x) return _ dsk = { - a: (f(a), ), - c: (f(c), ), - e: (f(e), ), - cc: (f(cc), ), + a: (f(a),), + c: (f(c),), + e: (f(e),), + cc: (f(cc),), b: (f(b), a, c), d: (f(d), c, e), bb: (f(bb), cc), @@ -326,29 +325,28 @@ def test_local_parents_of_reduction(abcde): Prefer to finish a1 stack before proceeding to b2 """ a, b, c, d, e = abcde - a1, a2, a3 = [a + i for i in "123"] - b1, b2, b3 = [b + i for i in "123"] - c1, c2, c3 = [c + i for i in "123"] + a1, a2, a3 = (a + i for i in "123") + b1, b2, b3 = (b + i for i in "123") + c1, c2, c3 = (c + i for i in "123") expected = [a3, a2, a1, b3, b2, b1, c3, c2, c1] log = [] def f(x): - def _(*args): log.append(x) return _ dsk = { - a3: (f(a3), ), + a3: (f(a3),), a2: (f(a2), a3), a1: (f(a1), a2), - b3: (f(b3), ), + b3: (f(b3),), b2: (f(b2), b3, a2), b1: (f(b1), b2), - c3: (f(c3), ), + c3: (f(c3),), c2: (f(c2), c3, b2), c1: (f(c1), c2), } @@ -370,14 +368,14 @@ def test_nearest_neighbor(abcde): This is difficult because all groups are connected. """ a, b, c, _, _ = abcde - a1, a2, a3, a4, a5, a6, a7, a8, a9 = [a + i for i in "123456789"] - b1, b2, b3, b4 = [b + i for i in "1234"] + a1, a2, a3, a4, a5, a6, a7, a8, a9 = (a + i for i in "123456789") + b1, b2, b3, b4 = (b + i for i in "1234") dsk = { - b1: (f, ), - b2: (f, ), - b3: (f, ), - b4: (f, ), + b1: (f,), + b2: (f,), + b3: (f,), + b4: (f,), a1: (f, b1), a2: (f, b1), a3: (f, b1, b2), @@ -397,15 +395,15 @@ def test_nearest_neighbor(abcde): def test_string_ordering(): - """ Prefer ordering tasks by name first """ - dsk = {("a", 1): (f, ), ("a", 2): (f, ), ("a", 3): (f, )} + """Prefer ordering tasks by name first.""" + dsk = {("a", 1): (f,), ("a", 2): (f,), ("a", 3): (f,)} o = order(dsk) assert o == {("a", 1): 0, ("a", 2): 1, ("a", 3): 2} def test_string_ordering_dependents(): - """ Prefer ordering tasks by name first even when in dependencies """ - dsk = {("a", 1): (f, "b"), ("a", 2): (f, "b"), ("a", 3): (f, "b"), "b": (f, )} + """Prefer ordering tasks by name first even when in dependencies.""" + dsk = {("a", 1): (f, "b"), ("a", 2): (f, "b"), ("a", 3): (f, "b"), "b": (f,)} o = order(dsk) assert o == {"b": 0, ("a", 1): 1, ("a", 2): 2, ("a", 3): 3} @@ -502,19 +500,19 @@ def test_map_overlap(abcde): """ a, b, c, d, e = abcde dsk = { - (e, 1): (f, ), + (e, 1): (f,), (d, 1): (f, (e, 1)), (c, 1): (f, (d, 1)), (b, 1): (f, (c, 1), (c, 2)), - (d, 2): (f, ), + (d, 2): (f,), (c, 2): (f, (d, 1), (d, 2), (d, 3)), - (e, 3): (f, ), + (e, 3): (f,), (d, 3): (f, (e, 3)), (c, 3): (f, (d, 3)), (b, 3): (f, (c, 2), (c, 3), (c, 4)), - (d, 4): (f, ), + (d, 4): (f,), (c, 4): (f, (d, 3), (d, 4), (d, 5)), - (e, 5): (f, ), + (e, 5): (f,), (d, 5): (f, (e, 5)), (c, 5): (f, (d, 5)), (b, 5): (f, (c, 4), (c, 5)), @@ -526,22 +524,22 @@ def test_map_overlap(abcde): def test_use_structure_not_keys(abcde): - """See https://github.com/dask/dask/issues/5584#issuecomment-554963958 + """See https://github.com/dask/dask/issues/5584#issuecomment-554963958. We were using key names to infer structure, which could result in funny behavior. """ a, b, _, _, _ = abcde dsk = { - (a, 0): (f, ), - (a, 1): (f, ), - (a, 2): (f, ), - (a, 3): (f, ), - (a, 4): (f, ), - (a, 5): (f, ), - (a, 6): (f, ), - (a, 7): (f, ), - (a, 8): (f, ), - (a, 9): (f, ), + (a, 0): (f,), + (a, 1): (f,), + (a, 2): (f,), + (a, 3): (f,), + (a, 4): (f,), + (a, 5): (f,), + (a, 6): (f,), + (a, 7): (f,), + (a, 8): (f,), + (a, 9): (f,), (b, 5): (f, (a, 2)), (b, 7): (f, (a, 0), (a, 2)), (b, 9): (f, (a, 7), (a, 0), (a, 2)), @@ -566,7 +564,7 @@ def test_use_structure_not_keys(abcde): def test_dont_run_all_dependents_too_early(abcde): - """ From https://github.com/dask/dask-ml/issues/206#issuecomment-395873372 """ + """From https://github.com/dask/dask-ml/issues/206#issuecomment-395873372.""" a, b, c, d, e = abcde depth = 10 dsk = {(a, 0): 0, (b, 0): 1, (c, 0): 2, (d, 0): (f, (a, 0), (b, 0), (c, 0))} @@ -581,13 +579,10 @@ def test_dont_run_all_dependents_too_early(abcde): def test_many_branches_use_ndependencies(abcde): - """From https://github.com/dask/dask/pull/5646#issuecomment-562700533 - - Sometimes we need larger or wider DAGs to test behavior. This test - ensures we choose the branch with more work twice in successtion. - This is important, because ``order`` may search along dependencies - and then along dependents. + """From https://github.com/dask/dask/pull/5646#issuecomment-562700533. + Sometimes we need larger or wider DAGs to test behavior. This test ensures we choose the branch with more work + twice in successtion. This is important, because ``order`` may search along dependencies and then along dependents. """ a, b, c, d, e = abcde dd = d + d @@ -694,32 +689,35 @@ def test_switching_dependents(abcde): def test_order_with_equal_dependents(abcde): - """From https://github.com/dask/dask/issues/5859#issuecomment-608422198 + """From https://github.com/dask/dask/issues/5859#issuecomment-608422198. See the visualization of `(maxima, argmax)` example from the above comment. This DAG has enough structure to exercise more parts of `order` - """ a, b, c, d, e = abcde dsk = {} abc = [a, b, c, d] for x in abc: - dsk.update({ - (x, 0): 0, - (x, 1): (f, (x, 0)), - (x, 2, 0): (f, (x, 0)), - (x, 2, 1): (f, (x, 1)), - }) + dsk.update( + { + (x, 0): 0, + (x, 1): (f, (x, 0)), + (x, 2, 0): (f, (x, 0)), + (x, 2, 1): (f, (x, 1)), + } + ) for i, y in enumerate(abc): - dsk.update({ - (x, 3, i): (f, (x, 2, 0), (y, 2, 1)), # cross x and y - (x, 4, i): (f, (x, 3, i)), - (x, 5, i, 0): (f, (x, 4, i)), - (x, 5, i, 1): (f, (x, 4, i)), - (x, 6, i, 0): (f, (x, 5, i, 0)), - (x, 6, i, 1): (f, (x, 5, i, 1)), - }) + dsk.update( + { + (x, 3, i): (f, (x, 2, 0), (y, 2, 1)), # cross x and y + (x, 4, i): (f, (x, 3, i)), + (x, 5, i, 0): (f, (x, 4, i)), + (x, 5, i, 1): (f, (x, 4, i)), + (x, 6, i, 0): (f, (x, 5, i, 0)), + (x, 6, i, 1): (f, (x, 5, i, 1)), + } + ) o = order(dsk) total = 0 for x in abc: diff --git a/tests/core/serve/test_dag/test_rewrite.py b/tests/core/serve/test_dag/test_rewrite.py index 64055f7211..97fbaf25f3 100644 --- a/tests/core/serve/test_dag/test_rewrite.py +++ b/tests/core/serve/test_dag/test_rewrite.py @@ -21,7 +21,7 @@ def test_head(): def test_args(): - assert args((inc, 1)) == (1, ) + assert args((inc, 1)) == (1,) assert args((add, 1, 2)) == (1, 2) assert args(1) == () assert args([1, 2, 3]) == [1, 2, 3] @@ -65,16 +65,16 @@ def repl_list(sd): return (list, x) -rule6 = RewriteRule((list, "x"), repl_list, ("x", )) +rule6 = RewriteRule((list, "x"), repl_list, ("x",)) def test_RewriteRule(): # Test extraneous vars are removed, varlist is correct - assert rule1.vars == ("a", ) + assert rule1.vars == ("a",) assert rule1._varlist == ["a"] - assert rule2.vars == ("a", ) + assert rule2.vars == ("a",) assert rule2._varlist == ["a", "a"] - assert rule3.vars == ("a", ) + assert rule3.vars == ("a",) assert rule3._varlist == ["a", "a"] assert rule4.vars == ("a", "b") assert rule4._varlist == ["b", "a"] @@ -97,32 +97,13 @@ def test_RuleSet(): { add: ( { - VAR: ({ - VAR: ({}, [1]), - 1: ({}, [0]) - }, []), - inc: ({ - VAR: ({ - inc: ({ - VAR: ({}, [2, 3]) - }, []) - }, []) - }, []), + VAR: ({VAR: ({}, [1]), 1: ({}, [0])}, []), + inc: ({VAR: ({inc: ({VAR: ({}, [2, 3])}, [])}, [])}, []), }, [], ), - list: ({ - VAR: ({}, [5]) - }, []), - sum: ({ - list: ({ - VAR: ({ - VAR: ({ - VAR: ({}, [4]) - }, []) - }, []) - }, []) - }, []), + list: ({VAR: ({}, [5])}, []), + sum: ({list: ({VAR: ({VAR: ({VAR: ({}, [4])}, [])}, [])}, [])}, []), }, [], ) diff --git a/tests/core/serve/test_dag/test_task.py b/tests/core/serve/test_dag/test_task.py index cd7479f5d5..260bc72d0b 100644 --- a/tests/core/serve/test_dag/test_task.py +++ b/tests/core/serve/test_dag/test_task.py @@ -52,7 +52,7 @@ def test_get_dependencies_nested(): def test_get_dependencies_empty(): - dsk = {"x": (inc, )} + dsk = {"x": (inc,)} assert get_dependencies(dsk, "x") == set() assert get_dependencies(dsk, "x", as_list=True) == [] @@ -181,7 +181,6 @@ class MyException(Exception): pass class F: - def __eq__(self, other): raise MyException() @@ -200,9 +199,7 @@ def test_subs_with_surprisingly_friendly_eq(): def test_subs_unexpected_hashable_key(): - class UnexpectedButHashable: - def __init__(self): self.name = "a" diff --git a/tests/core/serve/test_dag/test_utils.py b/tests/core/serve/test_dag/test_utils.py index 17315b5f29..7ce379d006 100644 --- a/tests/core/serve/test_dag/test_utils.py +++ b/tests/core/serve/test_dag/test_utils.py @@ -12,7 +12,6 @@ def test_funcname_long(): - def a_long_function_name_11111111111111111111111111111111111111111111111(): pass @@ -23,7 +22,6 @@ def a_long_function_name_11111111111111111111111111111111111111111111111(): @pytest.mark.skipif(not _CYTOOLZ_AVAILABLE, reason="the library `cytoolz` is not installed.") def test_funcname_cytoolz(): - @curry def foo(a, b, c): pass @@ -45,14 +43,13 @@ def test_partial_by_order(): def test_funcname(): assert funcname(np.floor_divide) == "floor_divide" assert funcname(partial(bool)) == "bool" - assert (funcname(operator.methodcaller("__getitem__")) == "operator.methodcaller('__getitem__')") + assert funcname(operator.methodcaller("__getitem__")) == "operator.methodcaller('__getitem__')" assert funcname(lambda x: x) == "lambda" def test_numpy_vectorize_funcname(): - def myfunc(a, b): - "Return a-b if a>b, otherwise return a+b" + """Return a-b if a>b, otherwise return a+b.""" if a > b: return a - b return a + b diff --git a/tests/core/serve/test_gridbase_validations.py b/tests/core/serve/test_gridbase_validations.py index 29c61aa688..17e094dd83 100644 --- a/tests/core/serve/test_gridbase_validations.py +++ b/tests/core/serve/test_gridbase_validations.py @@ -12,7 +12,6 @@ def test_metaclass_raises_if_expose_decorator_not_applied_to_method(): with pytest.raises(SyntaxError, match=r"expose.* decorator"): class FailedNoExposed(ModelComponent): - def __init__(self, model): pass @@ -23,7 +22,6 @@ def test_metaclass_raises_if_more_than_one_expose_decorator_applied(): with pytest.raises(SyntaxError, match=r"decorator must be applied to one"): class FailedTwoExposed(ModelComponent): - def __init__(self, model): pass @@ -44,7 +42,6 @@ def test_metaclass_raises_if_first_arg_in_init_is_not_model(): with pytest.raises(SyntaxError, match="__init__ must set 'model' as first"): class FailedModelArg(ModelComponent): - def __init__(self, foo): pass @@ -60,7 +57,6 @@ def test_metaclass_raises_if_second_arg_is_not_config(): with pytest.raises(SyntaxError, match="__init__ can only set 'config'"): class FailedConfig(ModelComponent): - def __init__(self, model, OTHER): pass @@ -76,7 +72,6 @@ def test_metaclass_raises_if_random_parameters_in_init(): with pytest.raises(SyntaxError, match="__init__ can only have 1 or 2 parameters"): class FailedInit(ModelComponent): - def __init__(self, model, config, FOO): pass @@ -93,7 +88,6 @@ def test_metaclass_raises_uses_restricted_method_name(): with pytest.raises(TypeError, match="bound methods/attrs named"): class FailedMethod_Inputs(ModelComponent): - def __init__(self, model): pass @@ -109,7 +103,6 @@ def inputs(self): with pytest.raises(TypeError, match="bound methods/attrs named"): class FailedMethod_Outputs(ModelComponent): - def __init__(self, model): pass @@ -125,7 +118,6 @@ def outputs(self): with pytest.raises(TypeError, match="bound methods/attrs named"): class FailedMethod_Name(ModelComponent): - def __init__(self, model): pass @@ -136,11 +128,12 @@ def predict(param): @property def uid(self): - return f'{self.uid}_SHOULD_NOT_RETURN' + return f"{self.uid}_SHOULD_NOT_RETURN" # Ensure that if we add more restricted names in the future, # there is a test for them as well. from flash.core.serve.component import _FLASH_SERVE_RESERVED_NAMES + assert set(_FLASH_SERVE_RESERVED_NAMES).difference({"inputs", "outputs", "uid"}) == set() @@ -149,7 +142,6 @@ def test_metaclass_raises_if_argument_values_of_expose_arent_subclasses_of_baset with pytest.raises(TypeError, match="must be subclass of"): class FailedExposedDecoratorInputs(ModelComponent): - def __init__(self, model): self.model = model @@ -162,7 +154,6 @@ def predict(param): with pytest.raises(TypeError, match="must be subclass of"): class FailedExposedDecoratorOutputs(ModelComponent): - def __init__(self, model): self.model = model @@ -175,7 +166,6 @@ def predict(param): with pytest.raises(TypeError, match="must be subclass of"): class FailedExposedDecoratorClass(ModelComponent): - def __init__(self, model): self.model = model @@ -191,13 +181,12 @@ def test_ModelComponent_raises_if_exposed_input_keys_differ_from_decorated_metho ): """This occurs when the instance is being initialized. - This is noted because it differes from some of the other metaclass validations - which will raise an exception at class defiition time. + This is noted because it differes from some of the other metaclass validations which will raise an exception at + class defiition time. """ from tests.core.serve.models import ClassificationInference class FailedExposedDecorator(ModelComponent): - def __init__(self, model): self.model = model @@ -215,12 +204,11 @@ def predict(self, param): def test_ModelComponent_raises_if_config_is_empty_dict(lightning_squeezenet1_1_obj): """This occurs when the instance is being initialized. - This is noted because it differes from some of the other metaclass validations - which will raise an exception at class defiition time. + This is noted because it differes from some of the other metaclass validations which will raise an exception at + class defiition time. """ class ConfigComponent(ModelComponent): - def __init__(self, model, config): pass @@ -236,12 +224,11 @@ def predict(self, param): def test_ModelComponent_raises_if_model_is_empty_iterable(): """This occurs when the instance is being initialized. - This is noted because it differes from some of the other metaclass validations - which will raise an exception at class defiition time. + This is noted because it differes from some of the other metaclass validations which will raise an exception at + class defiition time. """ class ConfigComponent(ModelComponent): - def __init__(self, model): pass diff --git a/tests/core/serve/test_integration.py b/tests/core/serve/test_integration.py index 2d3cebef27..4efafb548c 100644 --- a/tests/core/serve/test_integration.py +++ b/tests/core/serve/test_integration.py @@ -89,35 +89,21 @@ def test_serving_single_component_and_endpoint_no_composition(session_global_dat assert meta.json() == { "definitions": { "Ep_Ep_In_Image": { - "properties": { - "data": { - "title": "Data", - "type": "string" - } - }, + "properties": {"data": {"title": "Data", "type": "string"}}, "required": ["data"], "title": "Ep_Ep_In_Image", "type": "object", }, "Ep_Payload": { - "properties": { - "ep_in_image": { - "$ref": "#/definitions/Ep_Ep_In_Image" - } - }, + "properties": {"ep_in_image": {"$ref": "#/definitions/Ep_Ep_In_Image"}}, "required": ["ep_in_image"], "title": "Ep_Payload", "type": "object", }, }, "properties": { - "payload": { - "$ref": "#/definitions/Ep_Payload" - }, - "session": { - "title": "Session", - "type": "string" - }, + "payload": {"$ref": "#/definitions/Ep_Payload"}, + "session": {"title": "Session", "type": "string"}, }, "required": ["payload"], "title": "Ep_RequestModel", @@ -134,9 +120,7 @@ def test_serving_single_component_and_endpoint_no_composition(session_global_dat assert "result" in success.json() expected = { "session": "UUID", - "result": { - "ep_out_prediction": "goldfish, Carassius auratus" - }, + "result": {"ep_out_prediction": "goldfish, Carassius auratus"}, } assert expected == success.json() @@ -209,26 +193,15 @@ def test_serving_composed(session_global_datadir, lightning_squeezenet1_1_obj): body = { "session": "UUID", "payload": { - "image": { - "data": imgstr - }, - "section": { - "num": 10 - }, - "isle": { - "num": 4 - }, - "row": { - "num": 53 - }, + "image": {"data": imgstr}, + "section": {"num": 10}, + "isle": {"num": 4}, + "row": {"num": 53}, }, } success = tc.post("http://127.0.0.1:8000/predict_seat", json=body) assert success.json() == { - "result": { - "seat_number": 4799680, - "team": "buffalo bills, the ralph" - }, + "result": {"seat_number": 4799680, "team": "buffalo bills, the ralph"}, "session": "UUID", } resp = tc.get("http://127.0.0.1:8000/predict_seat/dag") @@ -295,26 +268,15 @@ def test_composed_does_not_eliminate_endpoint_serialization(session_global_datad body = { "session": "UUID", "payload": { - "image": { - "data": imgstr - }, - "section": { - "num": 10 - }, - "isle": { - "num": 4 - }, - "row": { - "num": 53 - }, + "image": {"data": imgstr}, + "section": {"num": 10}, + "isle": {"num": 4}, + "row": {"num": 53}, }, } success = tc.post("http://127.0.0.1:8000/predict_seat", json=body) assert success.json() == { - "result": { - "seat_number_out": 4799680, - "team_out": "buffalo bills, the ralph" - }, + "result": {"seat_number_out": 4799680, "team_out": "buffalo bills, the ralph"}, "session": "UUID", } resp = tc.get("http://127.0.0.1:8000/predict_seat/dag") @@ -339,10 +301,7 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ "section": seat_comp.inputs.section, "row": seat_comp.inputs.row, }, - outputs={ - "seat_number": seat_comp.outputs.seat_number, - "team": seat_comp.outputs.team - }, + outputs={"seat_number": seat_comp.outputs.seat_number, "team": seat_comp.outputs.team}, ) ep2 = Endpoint( route="/predict_seat_img", @@ -366,10 +325,7 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ "section": seat_comp.inputs.section, "row": seat_comp.inputs.row, }, - outputs={ - "seat_number": seat_comp.outputs.seat_number, - "team": seat_comp.outputs.team - }, + outputs={"seat_number": seat_comp.outputs.seat_number, "team": seat_comp.outputs.team}, ) composit = Composition( @@ -402,26 +358,15 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ body = { "session": "UUID", "payload": { - "image": { - "data": imgstr - }, - "section": { - "num": 10 - }, - "isle": { - "num": 4 - }, - "row": { - "num": 53 - }, + "image": {"data": imgstr}, + "section": {"num": 10}, + "isle": {"num": 4}, + "row": {"num": 53}, }, } success = tc.post("http://127.0.0.1:8000/predict_seat", json=body) assert success.json() == { - "result": { - "seat_number": 4799680, - "team": "buffalo bills, the ralph" - }, + "result": {"seat_number": 4799680, "team": "buffalo bills, the ralph"}, "session": "UUID", } @@ -438,26 +383,15 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ body = { "session": "UUID", "payload": { - "stadium": { - "label": "buffalo bills, the ralph" - }, - "section": { - "num": 10 - }, - "isle": { - "num": 4 - }, - "row": { - "num": 53 - }, + "stadium": {"label": "buffalo bills, the ralph"}, + "section": {"num": 10}, + "isle": {"num": 4}, + "row": {"num": 53}, }, } success = tc.post("http://127.0.0.1:8000/predict_seat_img_two", json=body) assert success.json() == { - "result": { - "seat_number": 16960000, - "team": "buffalo bills, the ralph" - }, + "result": {"seat_number": 16960000, "team": "buffalo bills, the ralph"}, "session": "UUID", } @@ -476,6 +410,7 @@ def test_cycle_in_connection_fails(session_global_datadir, lightning_squeezenet1 def test_composition_from_url_torchscript_servable(tmp_path): from flash.core.serve import expose, ModelComponent, Servable from flash.core.serve.types import Number + """ # Tensor x Tensor class MyModule(torch.nn.Module): @@ -494,7 +429,6 @@ def forward(self, a, b): TORCHSCRIPT_DOWNLOAD_URL = "https://github.com/pytorch/pytorch/raw/95489b590f00801bdee7f41783f30874883cf6bb/test/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt" # noqa E501 class ComponentTwoModels(ModelComponent): - def __init__(self, model): self.encoder = model["encoder"] self.decoder = model["decoder"] @@ -523,15 +457,11 @@ def do_my_predict(self, inp): body = { "session": "UUID", "payload": { - "ep_in": { - "num": 10 - }, + "ep_in": {"num": 10}, }, } success = tc.post("http://127.0.0.1:8000/predictr", json=body) assert success.json() == { - "result": { - "ep_out": 1.0 - }, + "result": {"ep_out": 1.0}, "session": "UUID", } diff --git a/tests/core/serve/test_types/test_bbox.py b/tests/core/serve/test_types/test_bbox.py index fb4fbe26c0..ca58a8f2a9 100644 --- a/tests/core/serve/test_types/test_bbox.py +++ b/tests/core/serve/test_types/test_bbox.py @@ -6,7 +6,7 @@ def test_deserialize(): bbox = BBox() - assert torch.allclose(bbox.deserialize((0, 0, 0, 0)), torch.zeros((4, ))) + assert torch.allclose(bbox.deserialize((0, 0, 0, 0)), torch.zeros((4,))) assert bbox.deserialize((0, 0, 0, 0)).shape == torch.Size([4]) with pytest.raises(ValueError): # only three elements, need four @@ -19,15 +19,17 @@ def test_deserialize(): bbox.deserialize({1: 1, 2: 2, 3: 3, 4: 4}) with pytest.raises(ValueError): # tuple instead of float - bbox.deserialize(( + bbox.deserialize( ( - 0, - 0, - ), - (0, 0), - (0, 0), - (0, 0), - )) + ( + 0, + 0, + ), + (0, 0), + (0, 0), + (0, 0), + ) + ) def test_serialize(): diff --git a/tests/core/serve/test_types/test_repeated.py b/tests/core/serve/test_types/test_repeated.py index b8fa64ef7e..2038dd29ec 100644 --- a/tests/core/serve/test_types/test_repeated.py +++ b/tests/core/serve/test_types/test_repeated.py @@ -12,11 +12,7 @@ def test_repeated_deserialize(): def test_repeated_serialize(session_global_datadir): repeated = Repeated(dtype=Label(path=str(session_global_datadir / "imagenet_labels.txt"))) - assert repeated.deserialize(*({ - "label": "chickadee" - }, { - "label": "stingray" - })) == ( + assert repeated.deserialize(*({"label": "chickadee"}, {"label": "stingray"})) == ( torch.tensor(19), torch.tensor(6), ) @@ -29,11 +25,7 @@ def test_repeated_max_len(): with pytest.raises(ValueError): repeated.deserialize(*({"label": "classA"}, {"label": "classA"}, {"label": "classB"})) - assert repeated.deserialize(*({ - "label": "classA" - }, { - "label": "classB" - })) == ( + assert repeated.deserialize(*({"label": "classA"}, {"label": "classB"})) == ( torch.tensor(0), torch.tensor(1), ) @@ -52,7 +44,6 @@ def test_repeated_max_len(): def test_repeated_non_serve_dtype(): - class NonServeDtype: pass diff --git a/tests/core/serve/test_types/test_table.py b/tests/core/serve/test_types/test_table.py index c1da29b703..5bccc64892 100644 --- a/tests/core/serve/test_types/test_table.py +++ b/tests/core/serve/test_types/test_table.py @@ -65,14 +65,7 @@ def test_deserialize(): with pytest.raises(RuntimeError): table.deserialize({"title1": {0: 100}, "title2": {0: 200}}) assert torch.allclose( - table.deserialize({ - "t1": { - 0: 100.0 - }, - "t2": { - 1: 200.0 - } - }), + table.deserialize({"t1": {0: 100.0}, "t2": {1: 200.0}}), torch.tensor([[100.0, float("nan")], [float("nan"), 200.0]], dtype=torch.float64), equal_nan=True, ) diff --git a/tests/core/test_classification.py b/tests/core/test_classification.py index 9281c36ab4..6cfa7a2c50 100644 --- a/tests/core/test_classification.py +++ b/tests/core/test_classification.py @@ -16,22 +16,22 @@ from flash.core.classification import Classes, FiftyOneLabels, Labels, Logits, Probabilities from flash.core.data.data_source import DefaultDataKeys -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE def test_classification_serializers(): example_output = torch.tensor([-0.1, 0.2, 0.3]) # 3 classes - labels = ['class_1', 'class_2', 'class_3'] + labels = ["class_1", "class_2", "class_3"] assert torch.allclose(torch.tensor(Logits().serialize(example_output)), example_output) assert torch.allclose(torch.tensor(Probabilities().serialize(example_output)), torch.softmax(example_output, -1)) assert Classes().serialize(example_output) == 2 - assert Labels(labels).serialize(example_output) == 'class_3' + assert Labels(labels).serialize(example_output) == "class_3" def test_classification_serializers_multi_label(): example_output = torch.tensor([-0.1, 0.2, 0.3]) # 3 classes - labels = ['class_1', 'class_2', 'class_3'] + labels = ["class_1", "class_2", "class_3"] assert torch.allclose(torch.tensor(Logits(multi_label=True).serialize(example_output)), example_output) assert torch.allclose( @@ -39,32 +39,33 @@ def test_classification_serializers_multi_label(): torch.sigmoid(example_output), ) assert Classes(multi_label=True).serialize(example_output) == [1, 2] - assert Labels(labels, multi_label=True).serialize(example_output) == ['class_2', 'class_3'] + assert Labels(labels, multi_label=True).serialize(example_output) == ["class_2", "class_3"] +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") def test_classification_serializers_fiftyone(): logits = torch.tensor([-0.1, 0.2, 0.3]) example_output = {DefaultDataKeys.PREDS: logits, DefaultDataKeys.METADATA: {"filepath": "something"}} # 3 classes - labels = ['class_1', 'class_2', 'class_3'] + labels = ["class_1", "class_2", "class_3"] predictions = FiftyOneLabels(return_filepath=True).serialize(example_output) - assert predictions["predictions"].label == '2' + assert predictions["predictions"].label == "2" assert predictions["filepath"] == "something" predictions = FiftyOneLabels(labels, return_filepath=True).serialize(example_output) - assert predictions["predictions"].label == 'class_3' + assert predictions["predictions"].label == "class_3" assert predictions["filepath"] == "something" predictions = FiftyOneLabels(store_logits=True).serialize(example_output) assert torch.allclose(torch.tensor(predictions.logits), logits) assert torch.allclose(torch.tensor(predictions.confidence), torch.softmax(logits, -1)[-1]) - assert predictions.label == '2' + assert predictions.label == "2" predictions = FiftyOneLabels(labels, store_logits=True).serialize(example_output) - assert predictions.label == 'class_3' + assert predictions.label == "class_3" predictions = FiftyOneLabels(store_logits=True, multi_label=True).serialize(example_output) assert torch.allclose(torch.tensor(predictions.logits), logits) - assert [c.label for c in predictions.classifications] == ['1', '2'] + assert [c.label for c in predictions.classifications] == ["1", "2"] predictions = FiftyOneLabels(labels, multi_label=True).serialize(example_output) - assert [c.label for c in predictions.classifications] == ['class_2', 'class_3'] + assert [c.label for c in predictions.classifications] == ["class_2", "class_3"] diff --git a/tests/core/test_data.py b/tests/core/test_data.py index a51d8756e2..156669a657 100644 --- a/tests/core/test_data.py +++ b/tests/core/test_data.py @@ -21,9 +21,8 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index): - return torch.rand(1, 28, 28), torch.randint(10, size=(1, )).item() + return torch.rand(1, 28, 28), torch.randint(10, size=(1,)).item() def __len__(self) -> int: return 10 @@ -49,7 +48,7 @@ def test_dataloaders(): dm.test_dataloader(), ]: x, y = next(iter(dl)) - assert x.shape == (1, 1, 28, 28) + assert x.shape == (4, 1, 28, 28) def test_cpu_count_none(): diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index ad44cc7dbf..809bfb41ab 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -24,9 +24,8 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index: int) -> Any: - return {"input": torch.rand(3, 64, 64), "target": torch.randint(10, size=(1, )).item()} + return {"input": torch.rand(3, 64, 64), "target": torch.randint(10, size=(1,)).item()} def __len__(self) -> int: return 100 @@ -34,7 +33,7 @@ def __len__(self) -> int: @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.parametrize( - "strategy", ['no_freeze', 'freeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat'] + "strategy", ["no_freeze", "freeze", "freeze_unfreeze", "unfreeze_milestones", None, "cls", "chocolat"] ) def test_finetuning(tmpdir: str, strategy): train_dl = torch.utils.data.DataLoader(DummyDataset()) @@ -43,7 +42,7 @@ def test_finetuning(tmpdir: str, strategy): trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) if strategy == "cls": strategy = NoFreeze() - if strategy == 'chocolat' or strategy is None: + if strategy == "chocolat" or strategy is None: with pytest.raises(MisconfigurationException, match="strategy should be provided"): trainer.finetune(task, train_dl, val_dl, strategy=strategy) else: diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 6336bdfb06..3d3b53b111 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math +from itertools import chain from numbers import Number from pathlib import Path from typing import Any, Tuple @@ -20,15 +22,17 @@ import pytest import pytorch_lightning as pl import torch +from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn, Tensor from torch.nn import functional as F from torch.utils.data import DataLoader import flash +from flash.core.adapter import Adapter from flash.core.classification import ClassificationTask from flash.core.data.process import DefaultPreprocess, Postprocess -from flash.core.utilities.imports import _PIL_AVAILABLE, _TABULAR_AVAILABLE, _TEXT_AVAILABLE +from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TEXT_AVAILABLE, Image from flash.image import ImageClassificationData, ImageClassifier from tests.helpers.utils import _IMAGE_TESTING, _TABULAR_TESTING @@ -37,27 +41,24 @@ else: TabularClassifier = None -if _PIL_AVAILABLE: - from PIL import Image -else: - - class Image: - Image = None - # ======== Mock functions ======== class DummyDataset(torch.utils.data.Dataset): + def __init__(self, num_samples: int = 9): + self.num_samples = num_samples def __getitem__(self, index: int) -> Tuple[Tensor, Number]: - return torch.rand(1, 28, 28), torch.randint(10, size=(1, )).item() + return torch.rand(1, 28, 28), torch.randint(10, size=(1,)).item() def __len__(self) -> int: - return 9 + return self.num_samples class PredictDummyDataset(DummyDataset): + def __init__(self, num_samples: int): + super().__init__(num_samples) def __getitem__(self, index: int) -> Tensor: return torch.rand(1, 28, 28) @@ -68,6 +69,80 @@ class DummyPostprocess(Postprocess): pass +class FixedDataset(torch.utils.data.Dataset): + def __init__(self, targets): + super().__init__() + + self.targets = targets + + def __getitem__(self, index: int) -> Tuple[Tensor, Number]: + return torch.rand(1), self.targets[index] + + def __len__(self) -> int: + return len(self.targets) + + +class OnesModel(nn.Module): + def __init__(self): + super().__init__() + + self.layer = nn.Linear(1, 2) + self.register_buffer("zeros", torch.zeros(2)) + self.register_buffer("zero_one", torch.tensor([0.0, 1.0])) + + def forward(self, x): + x = self.layer(x) + return x * self.zeros + self.zero_one + + +class Parent(ClassificationTask): + def __init__(self, child): + super().__init__() + + self.child = child + + def training_step(self, batch, batch_idx): + return self.child.training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + return self.child.validation_step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + return self.child.test_step(batch, batch_idx) + + def forward(self, x): + return self.child(x) + + +class GrandParent(Parent): + def __init__(self, child): + super().__init__(Parent(child)) + + +class BasicAdapter(Adapter): + def __init__(self, child): + super().__init__() + + self.child = child + + def training_step(self, batch, batch_idx): + return self.child.training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + return self.child.validation_step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + return self.child.test_step(batch, batch_idx) + + def forward(self, x): + return self.child(x) + + +class AdapterParent(Parent): + def __init__(self, child): + super().__init__(BasicAdapter(child)) + + # ================================ @@ -83,6 +158,21 @@ def test_classificationtask_train(tmpdir: str, metrics: Any): assert "test_nll_loss" in result[0] +@pytest.mark.parametrize("task", [Parent, GrandParent, AdapterParent]) +def test_nested_tasks(tmpdir, task): + model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) + train_dl = torch.utils.data.DataLoader(DummyDataset()) + val_dl = torch.utils.data.DataLoader(DummyDataset()) + child_task = ClassificationTask(model, loss_fn=F.nll_loss) + + parent_task = task(child_task) + + trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir) + trainer.fit(parent_task, train_dl, val_dl) + result = trainer.test(parent_task, val_dl) + assert "test_nll_loss" in result[0] + + def test_classificationtask_task_predict(): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) task = ClassificationTask(model, preprocess=DefaultPreprocess()) @@ -121,15 +211,12 @@ def _rand_image(): def test_classification_task_trainer_predict(tmpdir): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) task = ClassificationTask(model) - ds = PredictDummyDataset() - batch_size = 3 - predict_dl = torch.utils.data.DataLoader(ds, batch_size=batch_size) + ds = PredictDummyDataset(10) + batch_size = 6 + predict_dl = task.process_predict_dataset(ds, batch_size=batch_size) trainer = pl.Trainer(default_root_dir=tmpdir) predictions = trainer.predict(task, predict_dl) - assert len(predictions) == len(ds) // batch_size - for batch_pred in predictions: - assert len(batch_pred) == batch_size - assert all(y < 10 for y in batch_pred) + assert len(list(chain.from_iterable(predictions))) == 10 def test_task_datapipeline_save(tmpdir): @@ -158,24 +245,27 @@ def test_task_datapipeline_save(tmpdir): assert task.postprocess.test -@pytest.mark.parametrize(["cls", "filename"], [ - pytest.param( - ImageClassifier, - "image_classification_model.pt", - marks=pytest.mark.skipif( - not _IMAGE_TESTING, - reason="image packages aren't installed", - ) - ), - pytest.param( - TabularClassifier, - "tabular_classification_model.pt", - marks=pytest.mark.skipif( - not _TABULAR_TESTING, - reason="tabular packages aren't installed", - ) - ), -]) +@pytest.mark.parametrize( + ["cls", "filename"], + [ + pytest.param( + ImageClassifier, + "image_classification_model.pt", + marks=pytest.mark.skipif( + not _IMAGE_TESTING, + reason="image packages aren't installed", + ), + ), + pytest.param( + TabularClassifier, + "tabular_classification_model.pt", + marks=pytest.mark.skipif( + not _TABULAR_TESTING, + reason="tabular packages aren't installed", + ), + ), + ], +) def test_model_download(tmpdir, cls, filename): url = "https://flash-weights.s3.amazonaws.com/" with tmpdir.as_cwd(): @@ -191,7 +281,7 @@ def test_available_backbones(): class Foo(ImageClassifier): backbones = None - assert Foo.available_backbones() == [] + assert Foo.available_backbones() == {} def test_optimization(tmpdir): @@ -212,7 +302,7 @@ def test_optimization(tmpdir): model, optimizer=torch.optim.Adadelta, scheduler=torch.optim.lr_scheduler.StepLR, - scheduler_kwargs={"step_size": 1} + scheduler_kwargs={"step_size": 1}, ) optimizer, scheduler = task.configure_optimizers() assert isinstance(optimizer[0], torch.optim.Adadelta) @@ -241,11 +331,26 @@ def test_optimization(tmpdir): scheduler_kwargs={"num_warmup_steps": 0.1}, loss_fn=F.nll_loss, ) - trainer = flash.Trainer(max_epochs=1, limit_train_batches=2) + trainer = flash.Trainer(max_epochs=1, limit_train_batches=2, gpus=torch.cuda.device_count()) ds = DummyDataset() trainer.fit(task, train_dataloader=DataLoader(ds)) optimizer, scheduler = task.configure_optimizers() assert isinstance(optimizer[0], torch.optim.Adadelta) assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR) expected = get_linear_schedule_with_warmup.__name__ - assert scheduler[0].lr_lambdas[0].__qualname__.split('.')[0] == expected + assert scheduler[0].lr_lambdas[0].__qualname__.split(".")[0] == expected + + +def test_classification_task_metrics(): + train_dataset = FixedDataset([0, 1]) + val_dataset = FixedDataset([1, 1]) + + model = OnesModel() + + class CheckAccuracy(Callback): + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + assert math.isclose(trainer.callback_metrics["train_accuracy_epoch"], 0.5) + + task = ClassificationTask(model) + trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy(), gpus=torch.cuda.device_count()) + trainer.fit(task, train_dataloader=DataLoader(train_dataset), val_dataloaders=DataLoader(val_dataset)) diff --git a/tests/core/test_registry.py b/tests/core/test_registry.py index 061c6f4504..a230b869c0 100644 --- a/tests/core/test_registry.py +++ b/tests/core/test_registry.py @@ -27,20 +27,20 @@ def test_registry_raises(): def my_model(nc_input=5, nc_output=6): return nn.Linear(nc_input, nc_output), nc_input, nc_output - with pytest.raises(MisconfigurationException, match="You can only register a function, found: Linear"): - backbones(nn.Linear(1, 1), name="cho") + with pytest.raises(MisconfigurationException, match="You can only register a callable, found: 3"): + backbones(3, name="foo") - backbones(my_model, name="cho", override=True) + backbones(my_model, name="foo", override=True) - with pytest.raises(MisconfigurationException, match="Function with name: cho and metadata: {}"): - backbones(my_model, name="cho", override=False) + with pytest.raises(MisconfigurationException, match="Function with name: foo and metadata: {}"): + backbones(my_model, name="foo", override=False) with pytest.raises(KeyError, match="Found no matches"): - backbones.get("cho", foo="bar") + backbones.get("foo", baz="bar") - backbones.remove("cho") - with pytest.raises(KeyError, match="Key: cho is not in FlashRegistry"): - backbones.get("cho") + backbones.remove("foo") + with pytest.raises(KeyError, match="Key: foo is not in FlashRegistry"): + backbones.get("foo") with pytest.raises(TypeError, match="name` must be a str"): backbones(name=float) # noqa @@ -59,30 +59,30 @@ def my_model(nc_input=5, nc_output=6): assert mlp.weight.shape == (7, 5) # basic get - backbones(my_model, name="cho") - assert backbones.get("cho") + backbones(my_model, name="foo") + assert backbones.get("foo") # test override - backbones(my_model, name="cho", override=True) - functions = backbones.get("cho", strict=False) + backbones(my_model, name="foo", override=True) + functions = backbones.get("foo", strict=False) assert len(functions) == 1 # test metadata filtering - backbones(my_model, name="cho", namespace="timm", type="resnet") - backbones(my_model, name="cho", namespace="torchvision", type="resnet") - backbones(my_model, name="cho", namespace="timm", type="densenet") - backbones(my_model, name="cho", namespace="timm", type="alexnet") - function = backbones.get("cho", with_metadata=True, type="resnet", namespace="timm") - assert function["name"] == "cho" + backbones(my_model, name="foo", namespace="timm", type="resnet") + backbones(my_model, name="foo", namespace="torchvision", type="resnet") + backbones(my_model, name="foo", namespace="timm", type="densenet") + backbones(my_model, name="foo", namespace="timm", type="alexnet") + function = backbones.get("foo", with_metadata=True, type="resnet", namespace="timm") + assert function["name"] == "foo" assert function["metadata"] == {"namespace": "timm", "type": "resnet"} # test strict=False and with_metadata=False - functions = backbones.get("cho", namespace="timm", strict=False) + functions = backbones.get("foo", namespace="timm", strict=False) assert len(functions) == 3 assert all(callable(f) for f in functions) # test available keys - assert backbones.available_keys() == ['cho', 'cho', 'cho', 'cho', 'cho', 'my_model'] + assert backbones.available_keys() == ["foo", "foo", "foo", "foo", "foo", "my_model"] # todo (tchaton) Debug this test. @@ -100,8 +100,8 @@ def my_model(): assert caplog.messages == [ "Registering: my_model function with name: bar and metadata: {'foobar': True}", - 'Registering: my_model function with name: foo and metadata: {}', - 'Registering: my_model function with name: my_model and metadata: {}' + "Registering: my_model function with name: foo and metadata: {}", + "Registering: my_model function with name: my_model and metadata: {}", ] assert len(backbones) == 3 diff --git a/tests/core/test_trainer.py b/tests/core/test_trainer.py index 7bd330d83a..436bb48a2e 100644 --- a/tests/core/test_trainer.py +++ b/tests/core/test_trainer.py @@ -27,7 +27,6 @@ class DummyDataset(torch.utils.data.Dataset): - def __init__(self, predict: bool = False): self._predict = predict @@ -35,14 +34,13 @@ def __getitem__(self, index: int) -> Any: sample = torch.rand(1, 28, 28) if self._predict: return sample - return sample, torch.randint(10, size=(1, )).item() + return sample, torch.randint(10, size=(1,)).item() def __len__(self) -> int: return 100 class DummyClassifier(nn.Module): - def __init__(self): super().__init__() self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) @@ -85,7 +83,6 @@ def test_resolve_callbacks_invalid_strategy(tmpdir): class MultiFinetuneClassificationTask(ClassificationTask): - def configure_finetune_callback(self): return [NoFreeze(), NoFreeze()] @@ -99,7 +96,6 @@ def test_resolve_callbacks_multi_error(tmpdir): class FinetuneClassificationTask(ClassificationTask): - def configure_finetune_callback(self): return [NoFreeze()] @@ -115,14 +111,14 @@ def test_resolve_callbacks_override_warning(tmpdir): def test_add_argparse_args(): parser = ArgumentParser() parser = Trainer.add_argparse_args(parser) - args = parser.parse_args(['--gpus=1']) + args = parser.parse_args(["--gpus=1"]) assert args.gpus == 1 def test_from_argparse_args(): parser = ArgumentParser() parser = Trainer.add_argparse_args(parser) - args = parser.parse_args(['--max_epochs=200']) + args = parser.parse_args(["--max_epochs=200"]) trainer = Trainer.from_argparse_args(args) assert trainer.max_epochs == 200 assert isinstance(trainer, Trainer) diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py index 250aba1122..49d24bf7ab 100644 --- a/tests/core/test_utils.py +++ b/tests/core/test_utils.py @@ -20,7 +20,6 @@ class A: - def __call__(self, x): return True @@ -54,4 +53,4 @@ def test_get_callable_dict(): def test_download_data(tmpdir): path = os.path.join(tmpdir, "data") download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", path) - assert set(os.listdir(path)) == {'titanic', 'titanic.zip'} + assert set(os.listdir(path)) == {"titanic", "titanic.zip"} diff --git a/tests/core/utilities/__init__.py b/tests/core/utilities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py new file mode 100644 index 0000000000..1b664a02e5 --- /dev/null +++ b/tests/core/utilities/test_lightning_cli.py @@ -0,0 +1,721 @@ +# Adapted from the Lightning CLI: +# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/utilities/test_cli.py +import inspect +import json +import os +import pickle +import sys +from argparse import Namespace +from contextlib import redirect_stdout +from io import StringIO +from typing import List, Optional, Union +from unittest import mock + +import pytest +import torch +import yaml +from packaging import version +from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.plugins.environments import SLURMEnvironment + +from flash.core.utilities.imports import _TORCHVISION_AVAILABLE +from flash.core.utilities.lightning_cli import ( + instantiate_class, + LightningArgumentParser, + LightningCLI, + SaveConfigCallback, +) +from tests.helpers.boring_model import BoringDataModule, BoringModel + +torchvision_version = version.parse("0") +if _TORCHVISION_AVAILABLE: + torchvision_version = version.parse(__import__("torchvision").__version__) + + +@mock.patch("argparse.ArgumentParser.parse_args") +def test_default_args(mock_argparse, tmpdir): + """Tests default argument parser for Trainer.""" + mock_argparse.return_value = Namespace(**Trainer.default_attributes()) + + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + args = parser.parse_args([]) + + args.max_epochs = 5 + trainer = Trainer.from_argparse_args(args) + + assert isinstance(trainer, Trainer) + assert trainer.max_epochs == 5 + + +@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--weights_save_path=./"], []]) +def test_add_argparse_args_redefined(cli_args): + """Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness.""" + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + parser.add_lightning_class_args(Trainer, None) + + args = parser.parse_args(cli_args) + + # make sure we can pickle args + pickle.dumps(args) + + # Check few deprecated args are not in namespace: + for depr_name in ("gradient_clip", "nb_gpu_nodes", "max_nb_epochs"): + assert depr_name not in args + + trainer = Trainer.from_argparse_args(args=args) + pickle.dumps(trainer) + + assert isinstance(trainer, Trainer) + + +@pytest.mark.parametrize( + ["cli_args", "expected"], + [ + ("--auto_lr_find=True --auto_scale_batch_size=power", dict(auto_lr_find=True, auto_scale_batch_size="power")), + ( + "--auto_lr_find any_string --auto_scale_batch_size ON", + dict(auto_lr_find="any_string", auto_scale_batch_size=True), + ), + ("--auto_lr_find=Yes --auto_scale_batch_size=On", dict(auto_lr_find=True, auto_scale_batch_size=True)), + ("--auto_lr_find Off --auto_scale_batch_size No", dict(auto_lr_find=False, auto_scale_batch_size=False)), + ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", dict(auto_lr_find=True, auto_scale_batch_size=False)), + ("--limit_train_batches=100", dict(limit_train_batches=100)), + ("--limit_train_batches 0.8", dict(limit_train_batches=0.8)), + ("--weights_summary=null", dict(weights_summary=None)), + ( + "", + dict( + # These parameters are marked as Optional[...] in Trainer.__init__, + # with None as default. They should not be changed by the argparse + # interface. + min_steps=None, + max_steps=None, + log_gpu_memory=None, + distributed_backend=None, + weights_save_path=None, + truncated_bptt_steps=None, + resume_from_checkpoint=None, + profiler=None, + ), + ), + ], +) +def test_parse_args_parsing(cli_args, expected): + """Test parsing simple types and None optionals not modified.""" + cli_args = cli_args.split(" ") if cli_args else [] + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + parser.add_lightning_class_args(Trainer, None) + with mock.patch("sys.argv", ["any.py"] + cli_args): + args = parser.parse_args() + + for k, v in expected.items(): + assert getattr(args, k) == v + assert Trainer.from_argparse_args(args) + + +@pytest.mark.parametrize( + ["cli_args", "expected", "instantiate"], + [ + (["--gpus", "[0, 2]"], dict(gpus=[0, 2]), False), + (["--tpu_cores=[1,3]"], dict(tpu_cores=[1, 3]), False), + (['--accumulate_grad_batches={"5":3,"10":20}'], dict(accumulate_grad_batches={5: 3, 10: 20}), True), + ], +) +def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): + """Test parsing complex types.""" + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + parser.add_lightning_class_args(Trainer, None) + with mock.patch("sys.argv", ["any.py"] + cli_args): + args = parser.parse_args() + + for k, v in expected.items(): + assert getattr(args, k) == v + if instantiate: + assert Trainer.from_argparse_args(args) + + +@pytest.mark.parametrize( + ["cli_args", "expected_gpu"], + [ + ("--gpus 1", [0]), + ("--gpus 0,", [0]), + ("--gpus 0,1", [0, 1]), + ], +) +def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): + """Test parsing of gpus and instantiation of Trainer.""" + monkeypatch.setattr("torch.cuda.device_count", lambda: 2) + cli_args = cli_args.split(" ") if cli_args else [] + parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + parser.add_lightning_class_args(Trainer, None) + with mock.patch("sys.argv", ["any.py"] + cli_args): + args = parser.parse_args() + + trainer = Trainer.from_argparse_args(args) + assert trainer.data_parallel_device_ids == expected_gpu + + +@pytest.mark.skipif( + sys.version_info < (3, 7), + reason="signature inspection while mocking is not working in Python < 3.7 despite autospec", +) +@pytest.mark.parametrize( + ["cli_args", "extra_args"], + [ + ({}, {}), + (dict(logger=False), {}), + (dict(logger=False), dict(logger=True)), + (dict(logger=False), dict(checkpoint_callback=True)), + ], +) +def test_init_from_argparse_args(cli_args, extra_args): + unknown_args = dict(unknown_arg=0) + + # unkown args in the argparser/namespace should be ignored + with mock.patch("pytorch_lightning.Trainer.__init__", autospec=True, return_value=None) as init: + trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args) + expected = dict(cli_args) + expected.update(extra_args) # extra args should override any cli arg + init.assert_called_with(trainer, **expected) + + # passing in unknown manual args should throw an error + with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"): + Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args) + + +class Model(LightningModule): + def __init__(self, model_param: int): + super().__init__() + self.model_param = model_param + + +def model_builder(model_param: int) -> Model: + return Model(model_param) + + +def trainer_builder( + limit_train_batches: int, fast_dev_run: bool = False, callbacks: Optional[Union[List[Callback], Callback]] = None +) -> Trainer: + return Trainer(limit_train_batches=limit_train_batches, fast_dev_run=fast_dev_run, callbacks=callbacks) + + +@pytest.mark.parametrize(["trainer_class", "model_class"], [(Trainer, Model), (trainer_builder, model_builder)]) +def test_lightning_cli(trainer_class, model_class, monkeypatch): + """Test that LightningCLI correctly instantiates model, trainer and calls fit.""" + + expected_model = dict(model_param=7) + expected_trainer = dict(limit_train_batches=100) + + def fit(trainer, model): + for k, v in expected_model.items(): + assert getattr(model, k) == v + for k, v in expected_trainer.items(): + assert getattr(trainer, k) == v + save_callback = [x for x in trainer.callbacks if isinstance(x, SaveConfigCallback)] + assert len(save_callback) == 1 + save_callback[0].on_train_start(trainer, model) + + def on_train_start(callback, trainer, _): + config_dump = callback.parser.dump(callback.config, skip_none=False) + for k, v in expected_model.items(): + assert f" {k}: {v}" in config_dump + for k, v in expected_trainer.items(): + assert f" {k}: {v}" in config_dump + trainer.ran_asserts = True + + monkeypatch.setattr(Trainer, "fit", fit) + monkeypatch.setattr(SaveConfigCallback, "on_train_start", on_train_start) + + with mock.patch("sys.argv", ["any.py", "--model.model_param=7", "--trainer.limit_train_batches=100"]): + cli = LightningCLI(model_class, trainer_class=trainer_class, save_config_callback=SaveConfigCallback) + assert hasattr(cli.trainer, "ran_asserts") and cli.trainer.ran_asserts + + +def test_lightning_cli_args_callbacks(tmpdir): + + callbacks = [ + dict( + class_path="pytorch_lightning.callbacks.LearningRateMonitor", + init_args=dict(logging_interval="epoch", log_momentum=True), + ), + dict(class_path="pytorch_lightning.callbacks.ModelCheckpoint", init_args=dict(monitor="NAME")), + ] + + class TestModel(BoringModel): + def on_fit_start(self): + callback = [c for c in self.trainer.callbacks if isinstance(c, LearningRateMonitor)] + assert len(callback) == 1 + assert callback[0].logging_interval == "epoch" + assert callback[0].log_momentum is True + callback = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] + assert len(callback) == 1 + assert callback[0].monitor == "NAME" + self.trainer.ran_asserts = True + + with mock.patch("sys.argv", ["any.py", f"--trainer.callbacks={json.dumps(callbacks)}"]): + cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) + + assert cli.trainer.ran_asserts + + +def test_lightning_cli_configurable_callbacks(tmpdir): + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor") + + cli_args = [ + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + "--learning_rate_monitor.logging_interval=epoch", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = MyLightningCLI(BoringModel) + + callback = [c for c in cli.trainer.callbacks if isinstance(c, LearningRateMonitor)] + assert len(callback) == 1 + assert callback[0].logging_interval == "epoch" + + +def test_lightning_cli_args_cluster_environments(tmpdir): + plugins = [dict(class_path="pytorch_lightning.plugins.environments.SLURMEnvironment")] + + class TestModel(BoringModel): + def on_fit_start(self): + # Ensure SLURMEnvironment is set, instead of default LightningEnvironment + assert isinstance(self.trainer.accelerator_connector._cluster_environment, SLURMEnvironment) + self.trainer.ran_asserts = True + + with mock.patch("sys.argv", ["any.py", f"--trainer.plugins={json.dumps(plugins)}"]): + cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) + + assert cli.trainer.ran_asserts + + +def test_lightning_cli_args(tmpdir): + + cli_args = [ + f"--data.data_dir={tmpdir}", + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + "--trainer.weights_summary=null", + "--seed_everything=1234", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={"callbacks": [LearningRateMonitor()]}) + + assert cli.config["seed_everything"] == 1234 + config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml" + assert os.path.isfile(config_path) + with open(config_path) as f: + config = yaml.safe_load(f.read()) + assert "model" not in config and "model" not in cli.config # no arguments to include + assert config["data"] == cli.config["data"] + assert config["trainer"] == cli.config["trainer"] + + +def test_lightning_cli_save_config_cases(tmpdir): + + config_path = tmpdir / "config.yaml" + cli_args = [ + f"--trainer.default_root_dir={tmpdir}", + "--trainer.logger=False", + "--trainer.fast_dev_run=1", + ] + + # With fast_dev_run!=False config should not be saved + with mock.patch("sys.argv", ["any.py"] + cli_args): + LightningCLI(BoringModel) + assert not os.path.isfile(config_path) + + # With fast_dev_run==False config should be saved + cli_args[-1] = "--trainer.max_epochs=1" + with mock.patch("sys.argv", ["any.py"] + cli_args): + LightningCLI(BoringModel) + assert os.path.isfile(config_path) + + # If run again on same directory exception should be raised since config file already exists + with mock.patch("sys.argv", ["any.py"] + cli_args), pytest.raises(RuntimeError): + LightningCLI(BoringModel) + + +def test_lightning_cli_config_and_subclass_mode(tmpdir): + + config = dict( + model=dict(class_path="tests.helpers.boring_model.BoringModel"), + data=dict(class_path="tests.helpers.boring_model.BoringDataModule", init_args=dict(data_dir=str(tmpdir))), + trainer=dict(default_root_dir=str(tmpdir), max_epochs=1, weights_summary=None), + ) + config_path = tmpdir / "config.yaml" + with open(config_path, "w") as f: + f.write(yaml.dump(config)) + + with mock.patch("sys.argv", ["any.py", "--config", str(config_path)]): + cli = LightningCLI( + BoringModel, + BoringDataModule, + subclass_mode_model=True, + subclass_mode_data=True, + trainer_defaults={"callbacks": LearningRateMonitor()}, + ) + + config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml" + assert os.path.isfile(config_path) + with open(config_path) as f: + config = yaml.safe_load(f.read()) + assert config["model"] == cli.config["model"] + assert config["data"] == cli.config["data"] + assert config["trainer"] == cli.config["trainer"] + + +def any_model_any_data_cli(): + LightningCLI( + LightningModule, + LightningDataModule, + subclass_mode_model=True, + subclass_mode_data=True, + ) + + +def test_lightning_cli_help(): + + cli_args = ["any.py", "--help"] + out = StringIO() + with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit): + any_model_any_data_cli() + + assert "--print_config" in out.getvalue() + assert "--config" in out.getvalue() + assert "--seed_everything" in out.getvalue() + assert "--model.help" in out.getvalue() + assert "--data.help" in out.getvalue() + + skip_params = {"self"} + for param in inspect.signature(Trainer.__init__).parameters.keys(): + if param not in skip_params: + assert f"--trainer.{param}" in out.getvalue() + + cli_args = ["any.py", "--data.help=tests.helpers.boring_model.BoringDataModule"] + out = StringIO() + with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit): + any_model_any_data_cli() + + assert "--data.init_args.data_dir" in out.getvalue() + + +def test_lightning_cli_print_config(): + + cli_args = [ + "any.py", + "--seed_everything=1234", + "--model=tests.helpers.boring_model.BoringModel", + "--data=tests.helpers.boring_model.BoringDataModule", + "--print_config", + ] + + out = StringIO() + with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit): + any_model_any_data_cli() + + outval = yaml.safe_load(out.getvalue()) + assert outval["seed_everything"] == 1234 + assert outval["model"]["class_path"] == "tests.helpers.boring_model.BoringModel" + assert outval["data"]["class_path"] == "tests.helpers.boring_model.BoringDataModule" + + +def test_lightning_cli_submodules(tmpdir): + class MainModule(BoringModel): + def __init__( + self, + submodule1: LightningModule, + submodule2: LightningModule, + main_param: int = 1, + ): + super().__init__() + self.submodule1 = submodule1 + self.submodule2 = submodule2 + + config = """model: + main_param: 2 + submodule1: + class_path: tests.helpers.boring_model.BoringModel + submodule2: + class_path: tests.helpers.boring_model.BoringModel + """ + config_path = tmpdir / "config.yaml" + with open(config_path, "w") as f: + f.write(config) + + cli_args = [ + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + f"--config={str(config_path)}", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = LightningCLI(MainModule) + + assert cli.config["model"]["main_param"] == 2 + assert isinstance(cli.model.submodule1, BoringModel) + assert isinstance(cli.model.submodule2, BoringModel) + + +@pytest.mark.skipif(torchvision_version < version.parse("0.8.0"), reason="torchvision>=0.8.0 is required") +def test_lightning_cli_torch_modules(tmpdir): + class TestModule(BoringModel): + def __init__( + self, + activation: torch.nn.Module = None, + transform: Optional[List[torch.nn.Module]] = None, + ): + super().__init__() + self.activation = activation + self.transform = transform + + config = """model: + activation: + class_path: torch.nn.LeakyReLU + init_args: + negative_slope: 0.2 + transform: + - class_path: torchvision.transforms.Resize + init_args: + size: 64 + - class_path: torchvision.transforms.CenterCrop + init_args: + size: 64 + """ + config_path = tmpdir / "config.yaml" + with open(config_path, "w") as f: + f.write(config) + + cli_args = [ + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + f"--config={str(config_path)}", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = LightningCLI(TestModule) + + assert isinstance(cli.model.activation, torch.nn.LeakyReLU) + assert cli.model.activation.negative_slope == 0.2 + assert len(cli.model.transform) == 2 + assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform) + + +class BoringModelRequiredClasses(BoringModel): + def __init__( + self, + num_classes: int, + batch_size: int = 8, + ): + super().__init__() + self.num_classes = num_classes + self.batch_size = batch_size + + +class BoringDataModuleBatchSizeAndClasses(BoringDataModule): + def __init__( + self, + batch_size: int = 8, + ): + super().__init__() + self.batch_size = batch_size + self.num_classes = 5 # only available after instantiation + + +def test_lightning_cli_link_arguments(tmpdir): + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.link_arguments("data.batch_size", "model.batch_size") + parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate") + + cli_args = [ + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + "--data.batch_size=12", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses) + + assert cli.model.batch_size == 12 + assert cli.model.num_classes == 5 + + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.link_arguments("data.batch_size", "model.init_args.batch_size") + parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate") + + cli_args[-1] = "--model=tests.core.utilities.test_lightning_cli.BoringModelRequiredClasses" + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = MyLightningCLI( + BoringModelRequiredClasses, + BoringDataModuleBatchSizeAndClasses, + subclass_mode_model=True, + ) + + assert cli.model.batch_size == 8 + assert cli.model.num_classes == 5 + + +class EarlyExitTestModel(BoringModel): + def on_fit_start(self): + raise KeyboardInterrupt() + + +@pytest.mark.parametrize("logger", (False, True)) +@pytest.mark.parametrize( + "trainer_kwargs", + ( + dict(accelerator="ddp_cpu"), + dict(accelerator="ddp_cpu", plugins="ddp_find_unused_parameters_false"), + ), +) +def test_cli_ddp_spawn_save_config_callback(tmpdir, logger, trainer_kwargs): + with mock.patch("sys.argv", ["any.py"]), pytest.raises(KeyboardInterrupt): + LightningCLI( + EarlyExitTestModel, + trainer_defaults={ + "default_root_dir": str(tmpdir), + "logger": logger, + "max_steps": 1, + "max_epochs": 1, + **trainer_kwargs, + }, + ) + if logger: + config_dir = tmpdir / "lightning_logs" + # no more version dirs should get created + assert os.listdir(config_dir) == ["version_0"] + config_path = config_dir / "version_0" / "config.yaml" + else: + config_path = tmpdir / "config.yaml" + assert os.path.isfile(config_path) + + +def test_cli_config_overwrite(tmpdir): + trainer_defaults = {"default_root_dir": str(tmpdir), "logger": False, "max_steps": 1, "max_epochs": 1} + + with mock.patch("sys.argv", ["any.py"]): + LightningCLI(BoringModel, trainer_defaults=trainer_defaults) + with mock.patch("sys.argv", ["any.py"]), pytest.raises(RuntimeError, match="Aborting to avoid overwriting"): + LightningCLI(BoringModel, trainer_defaults=trainer_defaults) + with mock.patch("sys.argv", ["any.py"]): + LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=trainer_defaults) + + +def test_lightning_cli_optimizer(tmpdir): + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.add_optimizer_args(torch.optim.Adam) + + cli_args = [ + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + ] + + match = ( + "BoringModel.configure_optimizers` will be overridden by " + "`MyLightningCLI.add_configure_optimizers_method_to_model`" + ) + with mock.patch("sys.argv", ["any.py"] + cli_args), pytest.warns(UserWarning, match=match): + cli = MyLightningCLI(BoringModel) + + assert cli.model.configure_optimizers is not BoringModel.configure_optimizers + assert len(cli.trainer.optimizers) == 1 + assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) + assert len(cli.trainer.lr_schedulers) == 0 + + +def test_lightning_cli_optimizer_and_lr_scheduler(tmpdir): + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.add_optimizer_args(torch.optim.Adam) + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR) + + cli_args = [ + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + "--lr_scheduler.gamma=0.8", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = MyLightningCLI(BoringModel) + + assert cli.model.configure_optimizers is not BoringModel.configure_optimizers + assert len(cli.trainer.optimizers) == 1 + assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) + assert len(cli.trainer.lr_schedulers) == 1 + assert isinstance(cli.trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.ExponentialLR) + assert cli.trainer.lr_schedulers[0]["scheduler"].gamma == 0.8 + + +def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(tmpdir): + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam)) + parser.add_lr_scheduler_args((torch.optim.lr_scheduler.StepLR, torch.optim.lr_scheduler.ExponentialLR)) + + optimizer_arg = dict( + class_path="torch.optim.Adam", + init_args=dict(lr=0.01), + ) + lr_scheduler_arg = dict( + class_path="torch.optim.lr_scheduler.StepLR", + init_args=dict(step_size=50), + ) + cli_args = [ + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + f"--optimizer={json.dumps(optimizer_arg)}", + f"--lr_scheduler={json.dumps(lr_scheduler_arg)}", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = MyLightningCLI(BoringModel) + + assert len(cli.trainer.optimizers) == 1 + assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) + assert len(cli.trainer.lr_schedulers) == 1 + assert isinstance(cli.trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.StepLR) + assert cli.trainer.lr_schedulers[0]["scheduler"].step_size == 50 + + +def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir): + class MyLightningCLI(LightningCLI): + def add_arguments_to_parser(self, parser): + parser.add_optimizer_args(torch.optim.Adam, nested_key="optim1", link_to="model.optim1") + parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2") + parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR, link_to="model.scheduler") + + class TestModel(BoringModel): + def __init__( + self, + optim1: dict, + optim2: dict, + scheduler: dict, + ): + super().__init__() + self.optim1 = instantiate_class(self.parameters(), optim1) + self.optim2 = instantiate_class(self.parameters(), optim2) + self.scheduler = instantiate_class(self.optim1, scheduler) + + cli_args = [ + f"--trainer.default_root_dir={tmpdir}", + "--trainer.max_epochs=1", + "--optim2.class_path=torch.optim.SGD", + "--optim2.init_args.lr=0.01", + "--lr_scheduler.gamma=0.2", + ] + + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = MyLightningCLI(TestModel) + + assert isinstance(cli.model.optim1, torch.optim.Adam) + assert isinstance(cli.model.optim2, torch.optim.SGD) + assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR) diff --git a/tests/examples/test_integrations.py b/tests/examples/test_integrations.py index 3ba73cc309..5fe061c678 100644 --- a/tests/examples/test_integrations.py +++ b/tests/examples/test_integrations.py @@ -17,21 +17,24 @@ import pytest +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE from tests.examples.utils import run_test -from tests.helpers.utils import _IMAGE_TESTING root = Path(__file__).parent.parent.parent @mock.patch.dict(os.environ, {"FLASH_TESTING": "1"}) @pytest.mark.parametrize( - "folder, file", [ + "folder, file", + [ pytest.param( "fiftyone", "image_classification.py", - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="fiftyone library isn't installed") + marks=pytest.mark.skipif( + not (_IMAGE_AVAILABLE and _FIFTYONE_AVAILABLE), reason="fiftyone library isn't installed" + ), ), - ] + ], ) def test_integrations(tmpdir, folder, file): run_test(str(root / "flash_examples" / "integrations" / folder / file)) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 1decf2943b..75a5d7cd5f 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -20,7 +20,15 @@ import flash from flash.core.utilities.imports import _SKLEARN_AVAILABLE from tests.examples.utils import run_test -from tests.helpers.utils import _IMAGE_TESTING, _TABULAR_TESTING, _TEXT_TESTING, _VIDEO_TESTING +from tests.helpers.utils import ( + _AUDIO_TESTING, + _GRAPH_TESTING, + _IMAGE_TESTING, + _POINTCLOUD_TESTING, + _TABULAR_TESTING, + _TEXT_TESTING, + _VIDEO_TESTING, +) @mock.patch.dict(os.environ, {"FLASH_TESTING": "1"}) @@ -30,47 +38,80 @@ pytest.param( "custom_task.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed") ), + pytest.param( + "audio_classification.py", + marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed"), + ), + pytest.param( + "speech_recognition.py", + marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed"), + ), pytest.param( "image_classification.py", - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"), ), pytest.param( "image_classification_multi_label.py", - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"), ), # pytest.param("finetuning", "object_detection.py"), # TODO: takes too long. pytest.param( "semantic_segmentation.py", - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"), ), pytest.param( - "style_transfer.py", - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") + "style_transfer.py", marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") ), pytest.param( "summarization.py", marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed") ), pytest.param( "tabular_classification.py", - marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed") + marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"), ), pytest.param("template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")), pytest.param( "text_classification.py", - marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed") - ), - pytest.param( - "text_classification_multi_label.py", - marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed") + marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"), ), + # pytest.param( + # "text_classification_multi_label.py", + # marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed") + # ), pytest.param( "translation.py", marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed") ), pytest.param( "video_classification.py", - marks=pytest.mark.skipif(not _VIDEO_TESTING, reason="video libraries aren't installed") + marks=pytest.mark.skipif(not _VIDEO_TESTING, reason="video libraries aren't installed"), + ), + pytest.param( + "pointcloud_segmentation.py", + marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed"), ), - ] + pytest.param( + "pointcloud_detection.py", + marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed"), + ), + pytest.param( + "graph_classification.py", + marks=pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed"), + ), + ], ) def test_example(tmpdir, file): run_test(str(Path(flash.PROJECT_ROOT) / "flash_examples" / file)) + + +@mock.patch.dict(os.environ, {"FLASH_TESTING": "1"}) +@pytest.mark.parametrize( + "file", + [ + pytest.param( + "pointcloud_detection.py", + marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed"), + ), + ], +) +def test_example_2(tmpdir, file): + run_test(str(Path(flash.PROJECT_ROOT) / "flash_examples" / file)) diff --git a/tests/examples/utils.py b/tests/examples/utils.py index aeeacacd0d..cf713fcbd1 100644 --- a/tests/examples/utils.py +++ b/tests/examples/utils.py @@ -19,12 +19,12 @@ def call_script( filepath: str, args: Optional[List[str]] = None, - timeout: Optional[int] = 60 * 5, + timeout: Optional[int] = 60 * 10, ) -> Tuple[int, str, str]: - with open(filepath, 'r') as original: + with open(filepath) as original: data = original.read() - with open(filepath, 'w') as modified: + with open(filepath, "w") as modified: modified.write("import pytorch_lightning as pl\npl.seed_everything(42)\n" + data) if args is None: @@ -41,7 +41,7 @@ def call_script( stdout = stdout.decode("utf-8") stderr = stderr.decode("utf-8") - with open(filepath, 'w') as modified: + with open(filepath, "w") as modified: modified.write(data) return p.returncode, stdout, stderr diff --git a/tests/graph/__init__.py b/tests/graph/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/graph/classification/__init__.py b/tests/graph/classification/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/graph/classification/test_data.py b/tests/graph/classification/test_data.py new file mode 100644 index 0000000000..de4d08ff72 --- /dev/null +++ b/tests/graph/classification/test_data.py @@ -0,0 +1,132 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from flash.core.data.transforms import merge_transforms +from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE +from flash.graph.classification.data import GraphClassificationData, GraphClassificationPreprocess +from tests.helpers.utils import _GRAPH_TESTING + +if _TORCH_GEOMETRIC_AVAILABLE: + from torch_geometric.datasets import TUDataset + from torch_geometric.transforms import OneHotDegree + + +@pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed.") +class TestGraphClassificationPreprocess: + """Tests ``GraphClassificationPreprocess``.""" + + def test_smoke(self): + """A simple test that the class can be instantiated.""" + prep = GraphClassificationPreprocess() + assert prep is not None + + +@pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed.") +class TestGraphClassificationData: + """Tests ``GraphClassificationData``.""" + + def test_smoke(self): + dm = GraphClassificationData() + assert dm is not None + + def test_from_datasets(self, tmpdir): + tudataset = TUDataset(root=tmpdir, name="KKI") + train_dataset = tudataset + val_dataset = tudataset + test_dataset = tudataset + predict_dataset = tudataset + + # instantiate the data module + dm = GraphClassificationData.from_datasets( + train_dataset=train_dataset, + val_dataset=val_dataset, + test_dataset=test_dataset, + predict_dataset=predict_dataset, + train_transform=None, + val_transform=None, + test_transform=None, + predict_transform=None, + batch_size=2, + ) + assert dm is not None + assert dm.train_dataloader() is not None + assert dm.val_dataloader() is not None + assert dm.test_dataloader() is not None + + # check training data + data = next(iter(dm.train_dataloader())) + assert list(data.x.size())[1] == tudataset.num_features + assert list(data.y.size()) == [2] + + # check val data + data = next(iter(dm.val_dataloader())) + assert list(data.x.size())[1] == tudataset.num_features + assert list(data.y.size()) == [2] + + # check test data + data = next(iter(dm.test_dataloader())) + assert list(data.x.size())[1] == tudataset.num_features + assert list(data.y.size()) == [2] + + def test_transforms(self, tmpdir): + tudataset = TUDataset(root=tmpdir, name="KKI") + train_dataset = tudataset + val_dataset = tudataset + test_dataset = tudataset + predict_dataset = tudataset + + # instantiate the data module + dm = GraphClassificationData.from_datasets( + train_dataset=train_dataset, + val_dataset=val_dataset, + test_dataset=test_dataset, + predict_dataset=predict_dataset, + train_transform=merge_transforms( + GraphClassificationPreprocess.default_transforms(), + {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)}, + ), + val_transform=merge_transforms( + GraphClassificationPreprocess.default_transforms(), + {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)}, + ), + test_transform=merge_transforms( + GraphClassificationPreprocess.default_transforms(), + {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)}, + ), + predict_transform=merge_transforms( + GraphClassificationPreprocess.default_transforms(), + {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)}, + ), + batch_size=2, + ) + assert dm is not None + assert dm.train_dataloader() is not None + assert dm.val_dataloader() is not None + assert dm.test_dataloader() is not None + + # check training data + data = next(iter(dm.train_dataloader())) + assert list(data.x.size())[1] == tudataset.num_features * 2 + assert list(data.y.size()) == [2] + + # check val data + data = next(iter(dm.val_dataloader())) + assert list(data.x.size())[1] == tudataset.num_features * 2 + assert list(data.y.size()) == [2] + + # check test data + data = next(iter(dm.test_dataloader())) + assert list(data.x.size())[1] == tudataset.num_features * 2 + assert list(data.y.size()) == [2] diff --git a/tests/graph/classification/test_model.py b/tests/graph/classification/test_model.py new file mode 100644 index 0000000000..0813a6fb3a --- /dev/null +++ b/tests/graph/classification/test_model.py @@ -0,0 +1,88 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock + +import pytest +import torch + +from flash import Trainer +from flash.__main__ import main +from flash.core.data.data_pipeline import DataPipeline +from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE +from flash.graph.classification import GraphClassifier +from flash.graph.classification.data import GraphClassificationPreprocess +from tests.helpers.utils import _GRAPH_TESTING + +if _TORCH_GEOMETRIC_AVAILABLE: + from torch_geometric import datasets + + +@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") +def test_smoke(): + """A simple test that the class can be instantiated.""" + model = GraphClassifier(num_features=1, num_classes=1) + assert model is not None + + +@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") +def test_train(tmpdir): + """Tests that the model can be trained on a pytorch geometric dataset.""" + tudataset = datasets.TUDataset(root=tmpdir, name="KKI") + model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) + model.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess()) + train_dl = torch.utils.data.DataLoader(tudataset, batch_size=4) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, train_dl) + + +@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") +def test_val(tmpdir): + """Tests that the model can be validated on a pytorch geometric dataset.""" + tudataset = datasets.TUDataset(root=tmpdir, name="KKI") + model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) + model.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess()) + val_dl = torch.utils.data.DataLoader(tudataset, batch_size=4) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.validate(model, val_dl) + + +@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") +def test_test(tmpdir): + """Tests that the model can be tested on a pytorch geometric dataset.""" + tudataset = datasets.TUDataset(root=tmpdir, name="KKI") + model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) + model.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess()) + test_dl = torch.utils.data.DataLoader(tudataset, batch_size=4) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.test(model, test_dl) + + +@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") +def test_predict_dataset(tmpdir): + """Tests that we can generate predictions from a pytorch geometric dataset.""" + tudataset = datasets.TUDataset(root=tmpdir, name="KKI") + model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes) + data_pipe = DataPipeline(preprocess=GraphClassificationPreprocess()) + out = model.predict(tudataset, data_source="datasets", data_pipeline=data_pipe) + assert isinstance(out[0], int) + + +@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") +def test_cli(): + cli_args = ["flash", "graph_classification", "--trainer.fast_dev_run", "True"] + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py new file mode 100644 index 0000000000..e7ece2c0b8 --- /dev/null +++ b/tests/helpers/boring_model.py @@ -0,0 +1,135 @@ +# Adapted from: +# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/boring_model.py +from typing import Optional + +import torch +from pytorch_lightning import LightningDataModule, LightningModule +from torch.utils.data import DataLoader, Dataset, Subset + + +class RandomDataset(Dataset): + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return self.len + + +class BoringModel(LightningModule): + def __init__(self): + """Testing PL Module. + + Use as follows: + - subclass + - modify the behavior for what you want + + class TestModel(BaseTestModel): + def training_step(...): + # do your own thing + + or: + + model = BaseTestModel() + model.training_epoch_end = None + """ + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + def loss(self, batch, prediction): + # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls + return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) + + def step(self, x): + x = self(x) + out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) + return out + + def training_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_step_end(self, training_step_outputs): + return training_step_outputs + + def training_epoch_end(self, outputs) -> None: + torch.stack([x["loss"] for x in outputs]).mean() + + def validation_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"x": loss} + + def validation_epoch_end(self, outputs) -> None: + torch.stack([x["x"] for x in outputs]).mean() + + def test_step(self, batch, batch_idx): + output = self(batch) + loss = self.loss(batch, output) + return {"y": loss} + + def test_epoch_end(self, outputs) -> None: + torch.stack([x["y"] for x in outputs]).mean() + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + def predict_dataloader(self): + return DataLoader(RandomDataset(32, 64)) + + +class BoringDataModule(LightningDataModule): + def __init__(self, data_dir: str = "./"): + super().__init__() + self.data_dir = data_dir + self.non_picklable = None + self.checkpoint_state: Optional[str] = None + + def prepare_data(self): + self.random_full = RandomDataset(32, 64 * 4) + + def setup(self, stage: Optional[str] = None): + if stage == "fit" or stage is None: + self.random_train = Subset(self.random_full, indices=range(64)) + self.dims = self.random_train[0].shape + + if stage in ("fit", "validate") or stage is None: + self.random_val = Subset(self.random_full, indices=range(64, 64 * 2)) + + if stage == "test" or stage is None: + self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3)) + self.dims = getattr(self, "dims", self.random_test[0].shape) + + if stage == "predict" or stage is None: + self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4)) + self.dims = getattr(self, "dims", self.random_predict[0].shape) + + def train_dataloader(self): + return DataLoader(self.random_train) + + def val_dataloader(self): + return DataLoader(self.random_val) + + def test_dataloader(self): + return DataLoader(self.random_test) + + def predict_dataloader(self): + return DataLoader(self.random_predict) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 0fa1815db8..bd57cf570d 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -14,7 +14,10 @@ import os from flash.core.utilities.imports import ( + _AUDIO_AVAILABLE, + _GRAPH_AVAILABLE, _IMAGE_AVAILABLE, + _POINTCLOUD_AVAILABLE, _SERVE_AVAILABLE, _TABULAR_AVAILABLE, _TEXT_AVAILABLE, @@ -26,6 +29,9 @@ _TABULAR_TESTING = _TABULAR_AVAILABLE _TEXT_TESTING = _TEXT_AVAILABLE _SERVE_TESTING = _SERVE_AVAILABLE +_POINTCLOUD_TESTING = _POINTCLOUD_AVAILABLE +_GRAPH_TESTING = _GRAPH_AVAILABLE +_AUDIO_TESTING = _AUDIO_AVAILABLE if "FLASH_TEST_TOPIC" in os.environ: topic = os.environ["FLASH_TEST_TOPIC"] @@ -34,3 +40,6 @@ _TABULAR_TESTING = topic == "tabular" _TEXT_TESTING = topic == "text" _SERVE_TESTING = topic == "serve" + _POINTCLOUD_TESTING = topic == "pointcloud" + _GRAPH_TESTING = topic == "graph" + _AUDIO_TESTING = topic == "audio" diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index 183f3427a4..99bf240646 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import csv from pathlib import Path from typing import Any, List, Tuple @@ -21,7 +22,13 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.data.transforms import ApplyToKeys -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import ( + _FIFTYONE_AVAILABLE, + _IMAGE_AVAILABLE, + _MATPLOTLIB_AVAILABLE, + _PIL_AVAILABLE, + _TORCHVISION_AVAILABLE, +) from flash.image import ImageClassificationData from tests.helpers.utils import _IMAGE_TESTING @@ -72,9 +79,9 @@ def test_from_filepaths_smoke(tmpdir): assert img_data.test_dataloader() is None data = next(iter(img_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert sorted(list(labels.numpy())) == [1, 2] @@ -104,28 +111,29 @@ def test_from_filepaths_list_image_paths(tmpdir): # check training data data = next(iter(img_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here # check validation data data = next(iter(img_data.val_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [1, 4] # check test data data = next(iter(img_data.test_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [2, 5] -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") def test_from_filepaths_visualise(tmpdir): tmpdir = Path(tmpdir) @@ -160,7 +168,8 @@ def test_from_filepaths_visualise(tmpdir): dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") def test_from_filepaths_visualise_multilabel(tmpdir): tmpdir = Path(tmpdir) @@ -207,7 +216,7 @@ def test_from_filepaths_splits(tmpdir): _rand_image(img_size).save(tmpdir / "s.png") num_samples: int = 10 - val_split: float = .3 + val_split: float = 0.3 train_filepaths: List[str] = [str(tmpdir / "s.png") for _ in range(num_samples)] @@ -218,7 +227,7 @@ def test_from_filepaths_splits(tmpdir): _to_tensor = { "to_tensor_transform": nn.Sequential( ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), - ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor) + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), ), } @@ -234,9 +243,9 @@ def run(transform: Any = None): image_size=img_size, ) data = next(iter(dm.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (B, 3, H, W) - assert labels.shape == (B, ) + assert labels.shape == (B,) run(_to_tensor) @@ -257,9 +266,9 @@ def test_from_folders_only_train(tmpdir): img_data = ImageClassificationData.from_folders(train_dir, train_transform=None, batch_size=1) data = next(iter(img_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (1, 3, 196, 196) - assert labels.shape == (1, ) + assert labels.shape == (1,) assert img_data.val_dataloader() is None assert img_data.test_dataloader() is None @@ -287,20 +296,20 @@ def test_from_folders_train_val(tmpdir): ) data = next(iter(img_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) data = next(iter(img_data.val_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [0, 0] data = next(iter(img_data.test_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [0, 0] @@ -329,18 +338,18 @@ def test_from_filepaths_multilabel(tmpdir): ) data = next(iter(dm.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) data = next(iter(dm.val_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) torch.testing.assert_allclose(labels, torch.tensor(valid_labels)) data = next(iter(dm.test_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) torch.testing.assert_allclose(labels, torch.tensor(test_labels)) @@ -368,28 +377,28 @@ def test_from_data(data, from_function): # check training data data = next(iter(img_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here # check validation data data = next(iter(img_data.val_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [1, 4] # check test data data = next(iter(img_data.test_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert list(labels.numpy()) == [2, 5] -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.") def test_from_fiftyone(tmpdir): tmpdir = Path(tmpdir) @@ -426,23 +435,23 @@ def test_from_fiftyone(tmpdir): # check train data data = next(iter(img_data.train_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert sorted(list(labels.numpy())) == [0, 1] # check val data data = next(iter(img_data.val_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert sorted(list(labels.numpy())) == [0, 1] # check test data data = next(iter(img_data.test_dataloader())) - imgs, labels = data['input'], data['target'] + imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) assert sorted(list(labels.numpy())) == [0, 1] @@ -460,16 +469,103 @@ def test_from_datasets(): data = next(iter(img_data.train_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) # check validation data data = next(iter(img_data.val_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) # check test data data = next(iter(img_data.test_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, ) + assert labels.shape == (2,) + + +@pytest.fixture +def image_tmpdir(tmpdir): + (tmpdir / "train").mkdir() + Image.new("RGB", (128, 128)).save(str(tmpdir / "train" / "image_1.png")) + Image.new("RGB", (128, 128)).save(str(tmpdir / "train" / "image_2.png")) + return tmpdir / "train" + + +@pytest.fixture +def single_target_csv(image_tmpdir): + with open(image_tmpdir / "metadata.csv", "w") as csvfile: + fieldnames = ["image", "target"] + writer = csv.DictWriter(csvfile, fieldnames) + writer.writeheader() + writer.writerow({"image": "image_1", "target": "Ants"}) + writer.writerow({"image": "image_2", "target": "Bees"}) + return str(image_tmpdir / "metadata.csv") + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_from_csv_single_target(single_target_csv): + img_data = ImageClassificationData.from_csv( + "image", + "target", + train_file=single_target_csv, + batch_size=2, + num_workers=0, + ) + + # check training data + data = next(iter(img_data.train_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2,) + + +@pytest.fixture +def multi_target_csv(image_tmpdir): + with open(image_tmpdir / "metadata.csv", "w") as csvfile: + fieldnames = ["image", "target_1", "target_2"] + writer = csv.DictWriter(csvfile, fieldnames) + writer.writeheader() + writer.writerow({"image": "image_1", "target_1": 1, "target_2": 0}) + writer.writerow({"image": "image_2", "target_1": 1, "target_2": 1}) + return str(image_tmpdir / "metadata.csv") + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_from_csv_multi_target(multi_target_csv): + img_data = ImageClassificationData.from_csv( + "image", + ["target_1", "target_2"], + train_file=multi_target_csv, + batch_size=2, + num_workers=0, + ) + + # check training data + data = next(iter(img_data.train_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 2) + + +@pytest.fixture +def bad_csv_no_image(image_tmpdir): + with open(image_tmpdir / "metadata.csv", "w") as csvfile: + fieldnames = ["image", "target"] + writer = csv.DictWriter(csvfile, fieldnames) + writer.writeheader() + writer.writerow({"image": "image_3", "target": "Ants"}) + return str(image_tmpdir / "metadata.csv") + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_from_bad_csv_no_image(bad_csv_no_image): + with pytest.raises(ValueError, match="Found no matches"): + img_data = ImageClassificationData.from_csv( + "image", + ["target"], + train_file=bad_csv_no_image, + batch_size=1, + num_workers=0, + ) + _ = next(iter(img_data.train_dataloader())) diff --git a/tests/image/classification/test_data_model_integration.py b/tests/image/classification/test_data_model_integration.py index c15aca96ea..ba53d68637 100644 --- a/tests/image/classification/test_data_model_integration.py +++ b/tests/image/classification/test_data_model_integration.py @@ -18,7 +18,7 @@ import torch from flash import Trainer -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PIL_AVAILABLE +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _PIL_AVAILABLE from flash.image import ImageClassificationData, ImageClassifier from tests.helpers.utils import _IMAGE_TESTING @@ -62,7 +62,7 @@ def test_classification(tmpdir): trainer.finetune(model, datamodule=data, strategy="freeze") -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.") def test_classification_fiftyone(tmpdir): tmpdir = Path(tmpdir) diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py index 1cbaf589e2..7dc49a3abc 100644 --- a/tests/image/classification/test_model.py +++ b/tests/image/classification/test_model.py @@ -19,6 +19,7 @@ import torch from flash import Trainer +from flash.__main__ import main from flash.core.classification import Probabilities from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _IMAGE_AVAILABLE @@ -30,11 +31,10 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index): return { DefaultDataKeys.INPUT: torch.rand(3, 224, 224), - DefaultDataKeys.TARGET: torch.randint(10, size=(1, )).item(), + DefaultDataKeys.TARGET: torch.randint(10, size=(1,)).item(), } def __len__(self) -> int: @@ -42,14 +42,13 @@ def __len__(self) -> int: class DummyMultiLabelDataset(torch.utils.data.Dataset): - def __init__(self, num_classes: int): self.num_classes = num_classes def __getitem__(self, index): return { DefaultDataKeys.INPUT: torch.rand(3, 224, 224), - DefaultDataKeys.TARGET: torch.randint(0, 2, (self.num_classes, )), + DefaultDataKeys.TARGET: torch.randint(0, 2, (self.num_classes,)), } def __len__(self) -> int: @@ -61,17 +60,18 @@ def __len__(self) -> int: @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @pytest.mark.parametrize( - "backbone", + "backbone,metrics", [ - "resnet18", + ("resnet18", None), + ("resnet18", []), # "resnet34", # "resnet50", # "resnet101", # "resnet152", ], ) -def test_init_train(tmpdir, backbone): - model = ImageClassifier(10, backbone=backbone) +def test_init_train(tmpdir, backbone, metrics): + model = ImageClassifier(10, backbone=backbone, metrics=metrics) train_dl = torch.utils.data.DataLoader(DummyDataset()) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.finetune(model, train_dl, strategy="freeze_unfreeze") @@ -117,7 +117,7 @@ def test_multilabel(tmpdir): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))]) +@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32),))]) def test_jit(tmpdir, jitter, args): path = os.path.join(tmpdir, "test.pt") @@ -148,3 +148,13 @@ def test_serve(): def test_load_from_checkpoint_dependency_error(): with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[image]'")): ImageClassifier.load_from_checkpoint("not_a_real_checkpoint.pt") + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_cli(): + cli_args = ["flash", "image_classification", "--trainer.fast_dev_run", "True"] + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py index 18e2efa1da..50ce9fb196 100644 --- a/tests/image/detection/test_data.py +++ b/tests/image/detection/test_data.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import os from pathlib import Path @@ -5,9 +18,8 @@ import pytest from flash.core.data.data_source import DefaultDataKeys -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PIL_AVAILABLE +from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _PIL_AVAILABLE from flash.image.detection.data import ObjectDetectionData -from tests.helpers.utils import _IMAGE_TESTING if _PIL_AVAILABLE: from PIL import Image @@ -19,44 +31,53 @@ def _create_dummy_coco_json(dummy_json_path): dummy_json = { - "images": [{ - "id": 0, - 'width': 1920, - 'height': 1080, - 'file_name': 'sample_one.png', - }, { - "id": 1, - "width": 1920, - "height": 1080, - "file_name": "sample_two.png", - }], - "annotations": [{ - "id": 1, - "image_id": 0, - "category_id": 0, - "area": 150, - "bbox": [30, 40, 20, 20], - "iscrowd": 0, - }, { - "id": 2, - "image_id": 1, - "category_id": 0, - "area": 240, - "bbox": [50, 100, 280, 15], - "iscrowd": 0, - }, { - "id": 3, - "image_id": 1, - "category_id": 0, - "area": 170, - "bbox": [230, 130, 90, 180], - "iscrowd": 0, - }], - "categories": [{ - "id": 0, - "name": "person", - "supercategory": "person", - }] + "images": [ + { + "id": 0, + "width": 1920, + "height": 1080, + "file_name": "sample_one.png", + }, + { + "id": 1, + "width": 1920, + "height": 1080, + "file_name": "sample_two.png", + }, + ], + "annotations": [ + { + "id": 1, + "image_id": 0, + "category_id": 0, + "area": 150, + "bbox": [30, 40, 20, 20], + "iscrowd": 0, + }, + { + "id": 2, + "image_id": 1, + "category_id": 0, + "area": 240, + "bbox": [50, 100, 280, 15], + "iscrowd": 0, + }, + { + "id": 3, + "image_id": 1, + "category_id": 0, + "area": 170, + "bbox": [230, 130, 90, 180], + "iscrowd": 0, + }, + ], + "categories": [ + { + "id": 0, + "name": "person", + "supercategory": "person", + } + ], } with open(dummy_json_path, "w") as fp: @@ -68,8 +89,8 @@ def _create_synth_coco_dataset(tmpdir): train_dir.mkdir() (train_dir / "images").mkdir() - Image.new('RGB', (1920, 1080)).save(train_dir / "images" / "sample_one.png") - Image.new('RGB', (1920, 1080)).save(train_dir / "images" / "sample_two.png") + Image.new("RGB", (1920, 1080)).save(train_dir / "images" / "sample_one.png") + Image.new("RGB", (1920, 1080)).save(train_dir / "images" / "sample_two.png") (train_dir / "annotations").mkdir() dummy_json = train_dir / "annotations" / "sample.json" @@ -85,8 +106,8 @@ def _create_synth_fiftyone_dataset(tmpdir): img_dir = Path(tmpdir / "fo_imgs") img_dir.mkdir() - Image.new('RGB', (1920, 1080)).save(img_dir / "sample_one.png") - Image.new('RGB', (1920, 1080)).save(img_dir / "sample_two.png") + Image.new("RGB", (1920, 1080)).save(img_dir / "sample_one.png") + Image.new("RGB", (1920, 1080)).save(img_dir / "sample_two.png") dataset = fo.Dataset.from_dir( img_dir, @@ -121,20 +142,19 @@ def _create_synth_fiftyone_dataset(tmpdir): return dataset -@pytest.mark.skipif(not _IMAGE_TESTING, reason="pycocotools is not installed for testing") +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") def test_image_detector_data_from_coco(tmpdir): train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) - datamodule = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1) + datamodule = ObjectDetectionData.from_coco( + train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1, image_size=128 + ) data = next(iter(datamodule.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) assert datamodule.val_dataloader() is None assert datamodule.test_dataloader() is None @@ -148,40 +168,30 @@ def test_image_detector_data_from_coco(tmpdir): test_ann_file=coco_ann_path, batch_size=1, num_workers=0, + image_size=128, ) data = next(iter(datamodule.val_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) data = next(iter(datamodule.test_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") def test_image_detector_data_from_fiftyone(tmpdir): train_dataset = _create_synth_fiftyone_dataset(tmpdir) - datamodule = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1) + datamodule = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1, image_size=128) data = next(iter(datamodule.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) assert datamodule.val_dataloader() is None assert datamodule.test_dataloader() is None @@ -192,20 +202,13 @@ def test_image_detector_data_from_fiftyone(tmpdir): test_dataset=train_dataset, batch_size=1, num_workers=0, + image_size=128, ) data = next(iter(datamodule.val_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) data = next(iter(datamodule.test_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - - assert len(imgs) == 1 - assert imgs[0].shape == (3, 1080, 1920) - assert len(labels) == 1 - assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd'] + sample = data[0] + assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) diff --git a/tests/image/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py index 4c9ce93209..1a9d47b9f0 100644 --- a/tests/image/detection/test_data_model_integration.py +++ b/tests/image/detection/test_data_model_integration.py @@ -14,9 +14,10 @@ import os import pytest +import torch import flash -from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _PIL_AVAILABLE +from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _PIL_AVAILABLE from flash.image import ObjectDetector from flash.image.detection import ObjectDetectionData from tests.helpers.utils import _IMAGE_TESTING @@ -33,49 +34,48 @@ from tests.image.detection.test_data import _create_synth_fiftyone_dataset -@pytest.mark.skipif(not _IMAGE_TESTING, reason="pycocotools is not installed for testing") -@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing") -@pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", "resnet18")]) -def test_detection(tmpdir, model, backbone): +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn")]) +def test_detection(tmpdir, head, backbone): train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) data = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1) - model = ObjectDetector(model=model, backbone=backbone, num_classes=data.num_classes) + model = ObjectDetector(head=head, backbone=backbone, num_classes=data.num_classes) - trainer = flash.Trainer(fast_dev_run=True) + trainer = flash.Trainer(fast_dev_run=True, gpus=torch.cuda.device_count()) - trainer.finetune(model, data) + trainer.finetune(model, data, strategy="freeze") test_image_one = os.fspath(tmpdir / "test_one.png") test_image_two = os.fspath(tmpdir / "test_two.png") - Image.new('RGB', (512, 512)).save(test_image_one) - Image.new('RGB', (512, 512)).save(test_image_two) + Image.new("RGB", (512, 512)).save(test_image_one) + Image.new("RGB", (512, 512)).save(test_image_two) test_images = [str(test_image_one), str(test_image_two)] model.predict(test_images) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed for testing") +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") -@pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", "resnet18")]) -def test_detection_fiftyone(tmpdir, model, backbone): +@pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn")]) +def test_detection_fiftyone(tmpdir, head, backbone): train_dataset = _create_synth_fiftyone_dataset(tmpdir) data = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1) - model = ObjectDetector(model=model, backbone=backbone, num_classes=data.num_classes) + model = ObjectDetector(head=head, backbone=backbone, num_classes=data.num_classes) - trainer = flash.Trainer(fast_dev_run=True) + trainer = flash.Trainer(fast_dev_run=True, gpus=torch.cuda.device_count()) - trainer.finetune(model, data) + trainer.finetune(model, data, strategy="freeze") test_image_one = os.fspath(tmpdir / "test_one.png") test_image_two = os.fspath(tmpdir / "test_two.png") - Image.new('RGB', (512, 512)).save(test_image_one) - Image.new('RGB', (512, 512)).save(test_image_two) + Image.new("RGB", (512, 512)).save(test_image_one) + Image.new("RGB", (512, 512)).save(test_image_two) test_images = [str(test_image_one), str(test_image_two)] model.predict(test_images) diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index a610122783..f5fd1fba85 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -11,26 +11,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os +import random import re +from unittest import mock +import numpy as np import pytest import torch from pytorch_lightning import Trainer -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset +from flash.__main__ import main from flash.core.data.data_source import DefaultDataKeys -from flash.core.utilities.imports import _IMAGE_AVAILABLE +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE from flash.image import ObjectDetector from tests.helpers.utils import _IMAGE_TESTING +if _ICEVISION_AVAILABLE: + from icevision.data import Prediction + def collate_fn(samples): return {key: [sample[key] for sample in samples] for key in samples[0]} class DummyDetectionDataset(Dataset): - def __init__(self, img_shape, num_boxes, num_classes, length): super().__init__() self.img_shape = img_shape @@ -43,15 +48,27 @@ def __len__(self) -> int: def _random_bbox(self): c, h, w = self.img_shape - xs = torch.randint(w - 1, (2, )) - ys = torch.randint(h - 1, (2, )) - return [min(xs), min(ys), max(xs) + 1, max(ys) + 1] + xs = torch.randint(w - 1, (2,)) + ys = torch.randint(h - 1, (2,)) + return {"xmin": min(xs), "ymin": min(ys), "width": max(xs) - min(xs) + 1, "height": max(ys) - min(ys) + 1} def __getitem__(self, idx): - img = torch.rand(self.img_shape) - boxes = torch.tensor([self._random_bbox() for _ in range(self.num_boxes)]) - labels = torch.randint(self.num_classes, (self.num_boxes, )) - return {DefaultDataKeys.INPUT: img, DefaultDataKeys.TARGET: {"boxes": boxes, "labels": labels}} + sample = {} + + img = np.random.rand(*self.img_shape).astype(np.float32) + + sample[DefaultDataKeys.INPUT] = img + + sample[DefaultDataKeys.TARGET] = { + "bboxes": [], + "labels": [], + } + + for i in range(self.num_boxes): + sample[DefaultDataKeys.TARGET]["bboxes"].append(self._random_bbox()) + sample[DefaultDataKeys.TARGET]["labels"].append(random.randint(0, self.num_classes - 1)) + + return sample @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @@ -60,48 +77,58 @@ def test_init(): model.eval() batch_size = 2 - ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10) - dl = DataLoader(ds, collate_fn=collate_fn, batch_size=batch_size) + ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) + dl = model.process_predict_dataset(ds, batch_size=batch_size) data = next(iter(dl)) - img = data[DefaultDataKeys.INPUT] - out = model(img) + out = model(data) assert len(out) == batch_size - assert {"boxes", "labels", "scores"} <= out[0].keys() + assert all(isinstance(res, Prediction) for res in out) -@pytest.mark.parametrize("model", ["fasterrcnn", "retinanet"]) +@pytest.mark.parametrize("head", ["faster_rcnn", "retinanet"]) @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -def test_training(tmpdir, model): - model = ObjectDetector(num_classes=2, model=model, pretrained=False, pretrained_backbone=False) - ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10) - dl = DataLoader(ds, collate_fn=collate_fn) +def test_training(tmpdir, head): + model = ObjectDetector(num_classes=2, head=head, pretrained=False) + ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) + dl = model.process_train_dataset(ds, 2, 0, False, None) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.fit(model, dl) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -def test_jit(tmpdir): - path = os.path.join(tmpdir, "test.pt") - - model = ObjectDetector(2) - model.eval() - - model = torch.jit.script(model) # torch.jit.trace doesn't work with torchvision RCNN - - torch.jit.save(model, path) - model = torch.jit.load(path) - - out = model([torch.rand(3, 32, 32)]) - - # torchvision RCNN always returns a (Losses, Detections) tuple in scripting - out = out[1] - - assert {"boxes", "labels", "scores"} <= out[0].keys() +# TODO: resolve JIT issues +# @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +# def test_jit(tmpdir): +# path = os.path.join(tmpdir, "test.pt") +# +# model = ObjectDetector(2) +# model.eval() +# +# model = torch.jit.script(model) # torch.jit.trace doesn't work with torchvision RCNN +# +# torch.jit.save(model, path) +# model = torch.jit.load(path) +# +# out = model([torch.rand(3, 32, 32)]) +# +# # torchvision RCNN always returns a (Losses, Detections) tuple in scripting +# out = out[1] +# +# assert {"boxes", "labels", "scores"} <= out[0].keys() @pytest.mark.skipif(_IMAGE_AVAILABLE, reason="image libraries are installed.") def test_load_from_checkpoint_dependency_error(): with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[image]'")): ObjectDetector.load_from_checkpoint("not_a_real_checkpoint.pt") + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_cli(): + cli_args = ["flash", "object_detection", "--trainer.fast_dev_run", "True"] + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass diff --git a/tests/image/detection/test_serialization.py b/tests/image/detection/test_serialization.py index 93b6a3756b..8f707a229a 100644 --- a/tests/image/detection/test_serialization.py +++ b/tests/image/detection/test_serialization.py @@ -2,13 +2,13 @@ import torch from flash.core.data.data_source import DefaultDataKeys -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE from flash.image.detection.serialization import FiftyOneDetectionLabels +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") class TestFiftyOneDetectionLabels: - @staticmethod def test_smoke(): serial = FiftyOneDetectionLabels() @@ -16,7 +16,7 @@ def test_smoke(): @staticmethod def test_serialize_fiftyone(): - labels = ['class_1', 'class_2', 'class_3'] + labels = ["class_1", "class_2", "class_3"] serial = FiftyOneDetectionLabels() filepath_serial = FiftyOneDetectionLabels(return_filepath=True) threshold_serial = FiftyOneDetectionLabels(threshold=0.9) @@ -25,8 +25,7 @@ def test_serialize_fiftyone(): sample = { DefaultDataKeys.PREDS: [ { - "boxes": [torch.tensor(20), torch.tensor(30), - torch.tensor(40), torch.tensor(50)], + "boxes": [torch.tensor(20), torch.tensor(30), torch.tensor(40), torch.tensor(50)], "labels": torch.tensor(0), "scores": torch.tensor(0.5), }, diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py index 2700c3a37e..e823212ef7 100644 --- a/tests/image/embedding/test_model.py +++ b/tests/image/embedding/test_model.py @@ -23,7 +23,7 @@ @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))]) +@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32),))]) def test_jit(tmpdir, jitter, args): path = os.path.join(tmpdir, "test.pt") diff --git a/tests/image/instance_segmentation/__init__.py b/tests/image/instance_segmentation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/text/seq2seq/translation/test_metric.py b/tests/image/instance_segmentation/test_model.py similarity index 55% rename from tests/text/seq2seq/translation/test_metric.py rename to tests/image/instance_segmentation/test_model.py index 86b5784745..8f54742d24 100644 --- a/tests/text/seq2seq/translation/test_metric.py +++ b/tests/image/instance_segmentation/test_model.py @@ -11,15 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock + import pytest -import torch -from flash.text.seq2seq.translation.metric import BLEUScore +from flash.__main__ import main +from tests.helpers.utils import _IMAGE_TESTING -@pytest.mark.parametrize("smooth, expected", [(False, 0.7598), (True, 0.8091)]) -def test_bleu_score(smooth, expected): - translate_corpus = ['the cat is on the mat'.split()] - reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] - metric = BLEUScore(smooth=smooth) - assert torch.allclose(metric(translate_corpus, reference_corpus), torch.tensor(expected), 1e-4) +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_cli(): + cli_args = ["flash", "instance_segmentation", "--trainer.fast_dev_run", "True"] + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass diff --git a/tests/image/keypoint_detection/__init__.py b/tests/image/keypoint_detection/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/image/keypoint_detection/test_model.py b/tests/image/keypoint_detection/test_model.py new file mode 100644 index 0000000000..215ea9a71f --- /dev/null +++ b/tests/image/keypoint_detection/test_model.py @@ -0,0 +1,29 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock + +import pytest + +from flash.__main__ import main +from tests.helpers.utils import _IMAGE_TESTING + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_cli(): + cli_args = ["flash", "keypoint_detection", "--trainer.fast_dev_run", "True"] + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass diff --git a/tests/image/segmentation/test_backbones.py b/tests/image/segmentation/test_backbones.py index 0b2b452e17..4b8fb7a7a7 100644 --- a/tests/image/segmentation/test_backbones.py +++ b/tests/image/segmentation/test_backbones.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest -import torch -from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES -@pytest.mark.parametrize(["backbone"], [ - pytest.param("resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), - pytest.param("mobilenet_v3_large", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), -]) +@pytest.mark.parametrize( + ["backbone"], + [ + pytest.param("resnet50", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), + pytest.param("dpn131", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), + ], +) def test_semantic_segmentation_backbones_registry(backbone): - img = torch.rand(1, 3, 32, 32) - backbone = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone)(pretrained=False) + backbone = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone)() assert backbone - backbone.eval() - assert backbone(img) is not None + assert isinstance(backbone, str) diff --git a/tests/image/segmentation/test_data.py b/tests/image/segmentation/test_data.py index be898bdff3..b44a68da0d 100644 --- a/tests/image/segmentation/test_data.py +++ b/tests/image/segmentation/test_data.py @@ -9,7 +9,7 @@ from flash import Trainer from flash.core.data.data_source import DefaultDataKeys -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PIL_AVAILABLE +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE from flash.image import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess from tests.helpers.utils import _IMAGE_TESTING @@ -22,8 +22,8 @@ def build_checkboard(n, m, k=8): x = np.zeros((n, m)) - x[k::k * 2, ::k] = 1 - x[::k * 2, k::k * 2] = 1 + x[k :: k * 2, ::k] = 1 + x[:: k * 2, k :: k * 2] = 1 return x @@ -48,23 +48,22 @@ def create_random_data(image_files: List[str], label_files: List[str], size: Tup class TestSemanticSegmentationPreprocess: - - @pytest.mark.xfail(reaspn="parameters are marked as optional but it returns Misconficg error.") @staticmethod + @pytest.mark.xfail(reaspn="parameters are marked as optional but it returns Misconficg error.") def test_smoke(): prep = SemanticSegmentationPreprocess(num_classes=1) assert prep is not None -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") class TestSemanticSegmentationData: - @staticmethod + @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_smoke(): dm = SemanticSegmentationData() assert dm is not None @staticmethod + @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_folders(tmpdir): tmp_dir = Path(tmpdir) @@ -86,7 +85,7 @@ def test_from_folders(tmpdir): ] num_classes: int = 2 - img_size: Tuple[int, int] = (196, 196) + img_size: Tuple[int, int] = (128, 128) create_random_data(images, targets, img_size, num_classes) # instantiate the data module @@ -110,22 +109,23 @@ def test_from_folders(tmpdir): # check training data data = next(iter(dm.train_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, 196, 196) + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 128, 128) # check val data data = next(iter(dm.val_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, 196, 196) + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 128, 128) # check test data data = next(iter(dm.test_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, 196, 196) + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 128, 128) @staticmethod + @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_folders_warning(tmpdir): tmp_dir = Path(tmpdir) @@ -145,7 +145,7 @@ def test_from_folders_warning(tmpdir): ] num_classes: int = 2 - img_size: Tuple[int, int] = (196, 196) + img_size: Tuple[int, int] = (128, 128) create_random_data(images, targets, img_size, num_classes) # instantiate the data module @@ -164,10 +164,11 @@ def test_from_folders_warning(tmpdir): # check training data data = next(iter(dm.train_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert imgs.shape == (1, 3, 196, 196) - assert labels.shape == (1, 196, 196) + assert imgs.shape == (1, 3, 128, 128) + assert labels.shape == (1, 128, 128) @staticmethod + @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_files(tmpdir): tmp_dir = Path(tmpdir) @@ -186,7 +187,7 @@ def test_from_files(tmpdir): ] num_classes: int = 2 - img_size: Tuple[int, int] = (196, 196) + img_size: Tuple[int, int] = (128, 128) create_random_data(images, targets, img_size, num_classes) # instantiate the data module @@ -200,7 +201,7 @@ def test_from_files(tmpdir): test_targets=targets, batch_size=2, num_workers=0, - num_classes=num_classes + num_classes=num_classes, ) assert dm is not None assert dm.train_dataloader() is not None @@ -210,22 +211,23 @@ def test_from_files(tmpdir): # check training data data = next(iter(dm.train_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, 196, 196) + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 128, 128) # check val data data = next(iter(dm.val_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, 196, 196) + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 128, 128) # check test data data = next(iter(dm.test_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, 196, 196) + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 128, 128) @staticmethod + @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_files_warning(tmpdir): tmp_dir = Path(tmpdir) @@ -244,7 +246,7 @@ def test_from_files_warning(tmpdir): ] num_classes: int = 2 - img_size: Tuple[int, int] = (196, 196) + img_size: Tuple[int, int] = (128, 128) create_random_data(images, targets, img_size, num_classes) # instantiate the data module @@ -255,11 +257,12 @@ def test_from_files_warning(tmpdir): train_targets=targets + [str(tmp_dir / "labels_img4.png")], batch_size=2, num_workers=0, - num_classes=num_classes + num_classes=num_classes, ) - @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") @staticmethod + @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") + @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") def test_from_fiftyone(tmpdir): tmp_dir = Path(tmpdir) @@ -272,7 +275,7 @@ def test_from_fiftyone(tmpdir): ] num_classes: int = 2 - img_size: Tuple[int, int] = (196, 196) + img_size: Tuple[int, int] = (128, 128) for img_file in images: _rand_image(img_size).save(img_file) @@ -307,27 +310,29 @@ def test_from_fiftyone(tmpdir): # check training data data = next(iter(dm.train_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, 196, 196) + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 128, 128) # check val data data = next(iter(dm.val_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, 196, 196) + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 128, 128) # check test data data = next(iter(dm.test_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, 196, 196) + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 128, 128) # check predict data data = next(iter(dm.predict_dataloader())) imgs = data[DefaultDataKeys.INPUT] - assert imgs.shape == (2, 3, 196, 196) + assert imgs.shape == (2, 3, 128, 128) @staticmethod + @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") + @pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") def test_map_labels(tmpdir): tmp_dir = Path(tmpdir) @@ -351,7 +356,7 @@ def test_map_labels(tmpdir): } num_classes: int = len(labels_map.keys()) - img_size: Tuple[int, int] = (196, 196) + img_size: Tuple[int, int] = (128, 128) create_random_data(images, targets, img_size, num_classes) # instantiate the data module @@ -363,7 +368,7 @@ def test_map_labels(tmpdir): val_targets=targets, batch_size=2, num_workers=0, - num_classes=num_classes + num_classes=num_classes, ) assert dm is not None assert dm.train_dataloader() is not None @@ -379,13 +384,13 @@ def test_map_labels(tmpdir): # check training data data = next(iter(dm.train_dataloader())) imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, 196, 196) + assert imgs.shape == (2, 3, 128, 128) + assert labels.shape == (2, 128, 128) assert labels.min().item() == 0 assert labels.max().item() == 1 assert labels.dtype == torch.int64 # now train with `fast_dev_run` - model = SemanticSegmentation(num_classes=2, backbone="resnet50", head="fcn") + model = SemanticSegmentation(num_classes=2, backbone="resnet50", head="fpn") trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.finetune(model, dm, strategy="freeze_unfreeze") diff --git a/tests/image/segmentation/test_heads.py b/tests/image/segmentation/test_heads.py index ec90b03670..dbc4b3b38e 100644 --- a/tests/image/segmentation/test_heads.py +++ b/tests/image/segmentation/test_heads.py @@ -11,26 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import unittest.mock + import pytest import torch -from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE +from flash.image.segmentation import SemanticSegmentation from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS +from tests.helpers.utils import _IMAGE_TESTING @pytest.mark.parametrize( - "head", [ - pytest.param("fcn", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), - pytest.param("deeplabv3", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), - pytest.param("lraspp", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), - pytest.param("unet", marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")), - ] + "head", + [ + pytest.param("fpn", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), + pytest.param("deeplabv3", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), + pytest.param("unet", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), + ], ) def test_semantic_segmentation_heads_registry(head): img = torch.rand(1, 3, 32, 32) backbone = SEMANTIC_SEGMENTATION_BACKBONES.get("resnet50")(pretrained=False) - head = SEMANTIC_SEGMENTATION_HEADS.get(head)(backbone, 10) + head = SEMANTIC_SEGMENTATION_HEADS.get(head)(backbone=backbone, num_classes=10) assert backbone assert head head.eval() @@ -38,3 +42,26 @@ def test_semantic_segmentation_heads_registry(head): if isinstance(res, dict): res = res["out"] assert res.shape[1] == 10 + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@unittest.mock.patch("flash.image.segmentation.heads.smp") +def test_pretrained_weights(mock_smp): + mock_smp.create_model = unittest.mock.MagicMock() + available_weights = SemanticSegmentation.available_pretrained_weights("resnet18") + backbone = SEMANTIC_SEGMENTATION_BACKBONES.get("resnet18")() + SEMANTIC_SEGMENTATION_HEADS.get("unet")(backbone=backbone, num_classes=10, pretrained=True) + + kwargs = { + "arch": "unet", + "classes": 10, + "encoder_name": "resnet18", + "in_channels": 3, + "encoder_weights": "imagenet", + } + mock_smp.create_model.assert_called_with(**kwargs) + + for weight in available_weights: + SEMANTIC_SEGMENTATION_HEADS.get("unet")(backbone=backbone, num_classes=10, pretrained=weight) + kwargs["encoder_weights"] = weight + mock_smp.create_model.assert_called_with(**kwargs) diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py index c16b54b951..6715ebfc50 100644 --- a/tests/image/segmentation/test_model.py +++ b/tests/image/segmentation/test_model.py @@ -21,6 +21,7 @@ import torch from flash import Trainer +from flash.__main__ import main from flash.core.data.data_pipeline import DataPipeline from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _IMAGE_AVAILABLE @@ -56,12 +57,12 @@ def test_smoke(): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @pytest.mark.parametrize("num_classes", [8, 256]) -@pytest.mark.parametrize("img_shape", [(1, 3, 224, 192), (2, 3, 127, 212)]) +@pytest.mark.parametrize("img_shape", [(1, 3, 224, 192), (2, 3, 128, 256)]) def test_forward(num_classes, img_shape): model = SemanticSegmentation( num_classes=num_classes, backbone="resnet50", - head="fcn", + head="fpn", ) B, C, H, W = img_shape @@ -103,28 +104,28 @@ def test_unfreeze(): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_predict_tensor(): - img = torch.rand(1, 3, 10, 20) - model = SemanticSegmentation(2) + img = torch.rand(1, 3, 64, 64) + model = SemanticSegmentation(2, backbone="mobilenetv3_large_100") data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1)) out = model.predict(img, data_source="tensors", data_pipeline=data_pipe) assert isinstance(out[0], list) - assert len(out[0]) == 10 - assert len(out[0][0]) == 20 + assert len(out[0]) == 64 + assert len(out[0][0]) == 64 @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_predict_numpy(): - img = np.ones((1, 3, 10, 20)) - model = SemanticSegmentation(2) + img = np.ones((1, 3, 64, 64)) + model = SemanticSegmentation(2, backbone="mobilenetv3_large_100") data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1)) out = model.predict(img, data_source="numpy", data_pipeline=data_pipe) assert isinstance(out[0], list) - assert len(out[0]) == 10 - assert len(out[0][0]) == 20 + assert len(out[0]) == 64 + assert len(out[0][0]) == 64 @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))]) +@pytest.mark.parametrize("jitter, args", [(torch.jit.trace, (torch.rand(1, 3, 32, 32),))]) def test_jit(tmpdir, jitter, args): path = os.path.join(tmpdir, "test.pt") @@ -155,3 +156,18 @@ def test_serve(): def test_load_from_checkpoint_dependency_error(): with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[image]'")): SemanticSegmentation.load_from_checkpoint("not_a_real_checkpoint.pt") + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_available_pretrained_weights(): + assert SemanticSegmentation.available_pretrained_weights("resnet18") == ["imagenet", "ssl", "swsl"] + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_cli(): + cli_args = ["flash", "semantic_segmentation", "--trainer.fast_dev_run", "True"] + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass diff --git a/tests/image/segmentation/test_serialization.py b/tests/image/segmentation/test_serialization.py index 09a03ad75c..0e7477348a 100644 --- a/tests/image/segmentation/test_serialization.py +++ b/tests/image/segmentation/test_serialization.py @@ -1,13 +1,27 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest import torch from flash.core.data.data_source import DefaultDataKeys -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE from flash.image.segmentation.serialization import FiftyOneSegmentationLabels, SegmentationLabels +from tests.helpers.utils import _IMAGE_TESTING class TestSemanticSegmentationLabels: - + @pytest.mark.skipif(not _IMAGE_TESTING, "image libraries aren't installed.") @staticmethod def test_smoke(): serial = SegmentationLabels() @@ -15,6 +29,7 @@ def test_smoke(): assert serial.labels_map is None assert serial.visualize is False + @pytest.mark.skipif(not _IMAGE_TESTING, "image libraries aren't installed.") @staticmethod def test_exception(): serial = SegmentationLabels() @@ -27,6 +42,7 @@ def test_exception(): sample = torch.zeros(2, 3) serial.serialize(sample) + @pytest.mark.skipif(not _IMAGE_TESTING, "image libraries aren't installed.") @staticmethod def test_serialize(): serial = SegmentationLabels() @@ -39,6 +55,7 @@ def test_serialize(): assert torch.tensor(classes)[1, 2] == 1 assert torch.tensor(classes)[0, 1] == 3 + @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") @staticmethod def test_serialize_fiftyone(): @@ -51,9 +68,7 @@ def test_serialize_fiftyone(): sample = { DefaultDataKeys.PREDS: preds, - DefaultDataKeys.METADATA: { - "filepath": "something" - }, + DefaultDataKeys.METADATA: {"filepath": "something"}, } segmentation = serial.serialize(sample) diff --git a/tests/image/style_transfer/test_model.py b/tests/image/style_transfer/test_model.py index d054986978..8573b70784 100644 --- a/tests/image/style_transfer/test_model.py +++ b/tests/image/style_transfer/test_model.py @@ -1,9 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os import re +from unittest import mock import pytest import torch +from flash.__main__ import main from flash.core.utilities.imports import _IMAGE_AVAILABLE from flash.image.style_transfer import StyleTransfer from tests.helpers.utils import _IMAGE_TESTING @@ -34,6 +49,9 @@ def test_jit(tmpdir): model = StyleTransfer() model.eval() + model.loss_fn = None + model.perceptual_loss = None # TODO: Document this + model = torch.jit.trace(model, torch.rand(1, 3, 32, 32)) # torch.jit.script doesn't work with pystiche torch.jit.save(model, path) @@ -48,3 +66,13 @@ def test_jit(tmpdir): def test_load_from_checkpoint_dependency_error(): with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[image]'")): StyleTransfer.load_from_checkpoint("not_a_real_checkpoint.pt") + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_cli(): + cli_args = ["flash", "style_transfer", "--trainer.fast_dev_run", "True"] + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass diff --git a/tests/image/test_backbones.py b/tests/image/test_backbones.py index 6036927555..c751426c76 100644 --- a/tests/image/test_backbones.py +++ b/tests/image/test_backbones.py @@ -14,19 +14,20 @@ import urllib.error import pytest -from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE -from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TIMM_AVAILABLE -from flash.image.backbones import catch_url_error, IMAGE_CLASSIFIER_BACKBONES +from flash.core.utilities.url_error import catch_url_error +from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES +from tests.helpers.utils import _IMAGE_TESTING -@pytest.mark.parametrize(["backbone", "expected_num_features"], [ - pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), - pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _TIMM_AVAILABLE, reason="No timm")), - pytest.param("simclr-imagenet", 2048, marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")), - pytest.param("swav-imagenet", 2048, marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")), - pytest.param("mobilenet_v2", 1280, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), -]) +@pytest.mark.parametrize( + ["backbone", "expected_num_features"], + [ + pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")), + pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No timm")), + pytest.param("mobilenet_v2", 1280, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")), + ], +) def test_image_classifier_backbones_registry(backbone, expected_num_features): backbone_fn = IMAGE_CLASSIFIER_BACKBONES.get(backbone) backbone_model, num_features = backbone_fn(pretrained=False) @@ -34,11 +35,41 @@ def test_image_classifier_backbones_registry(backbone, expected_num_features): assert num_features == expected_num_features -def test_pretrained_backbones_catch_url_error(): +@pytest.mark.parametrize( + ["backbone", "pretrained", "expected_num_features"], + [ + pytest.param( + "resnet50", + "supervised", + 2048, + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision"), + ), + pytest.param("resnet50", "simclr", 2048, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")), + ], +) +def test_pretrained_weights_registry(backbone, pretrained, expected_num_features): + backbone_fn = IMAGE_CLASSIFIER_BACKBONES.get(backbone) + backbone_model, num_features = backbone_fn(pretrained=pretrained) + assert backbone_model + assert num_features == expected_num_features + +@pytest.mark.parametrize( + ["backbone", "pretrained"], + [ + pytest.param("resnet50w2", True), + pytest.param("resnet50w4", "supervised"), + ], +) +def test_wide_resnets(backbone, pretrained): + with pytest.raises(KeyError, match=f"Supervised pretrained weights not available for {backbone}"): + IMAGE_CLASSIFIER_BACKBONES.get(backbone)(pretrained=pretrained) + + +def test_pretrained_backbones_catch_url_error(): def raise_error_if_pretrained(pretrained=False): if pretrained: - raise urllib.error.URLError('Test error') + raise urllib.error.URLError("Test error") with pytest.warns(UserWarning, match="Failed to download pretrained weights"): catch_url_error(raise_error_if_pretrained)(pretrained=True) diff --git a/tests/pointcloud/__init__.py b/tests/pointcloud/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pointcloud/detection/__init__.py b/tests/pointcloud/detection/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pointcloud/detection/test_data.py b/tests/pointcloud/detection/test_data.py new file mode 100644 index 0000000000..b337fa28da --- /dev/null +++ b/tests/pointcloud/detection/test_data.py @@ -0,0 +1,58 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from os.path import join + +import pytest +import torch +from pytorch_lightning import seed_everything + +from flash import Trainer +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.utils import download_data +from flash.pointcloud.detection import PointCloudObjectDetector, PointCloudObjectDetectorData +from tests.helpers.utils import _POINTCLOUD_TESTING + +if _POINTCLOUD_TESTING: + from flash.pointcloud.detection.open3d_ml.backbones import ObjectDetectBatchCollator + + +@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +def test_pointcloud_object_detection_data(tmpdir): + + seed_everything(52) + + download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_micro.zip", tmpdir) + + dm = PointCloudObjectDetectorData.from_folders(train_folder=join(tmpdir, "KITTI_Micro", "Kitti", "train")) + + class MockModel(PointCloudObjectDetector): + def training_step(self, batch, batch_idx: int): + assert isinstance(batch, ObjectDetectBatchCollator) + assert len(batch.point) == 2 + assert batch.point[0][1].shape == torch.Size([4]) + assert len(batch.bboxes) > 1 + assert batch.attr[0]["name"] in ("000000.bin", "000001.bin") + assert batch.attr[1]["name"] in ("000000.bin", "000001.bin") + + num_classes = 19 + model = MockModel(backbone="pointpillars_kitti", num_classes=num_classes) + trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0) + trainer.fit(model, dm) + + predict_path = join(tmpdir, "KITTI_Micro", "Kitti", "predict") + model.eval() + + predictions = model.predict([join(predict_path, "scans/000000.bin")]) + assert torch.stack(predictions[0][DefaultDataKeys.INPUT]).shape[1] == 4 + assert len(predictions[0][DefaultDataKeys.PREDS]) == 158 diff --git a/tests/pointcloud/detection/test_model.py b/tests/pointcloud/detection/test_model.py new file mode 100644 index 0000000000..deafc06faf --- /dev/null +++ b/tests/pointcloud/detection/test_model.py @@ -0,0 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from flash.pointcloud.detection import PointCloudObjectDetector +from tests.helpers.utils import _POINTCLOUD_TESTING + + +@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +def test_backbones(): + + backbones = PointCloudObjectDetector.available_backbones() + assert backbones == ["pointpillars", "pointpillars_kitti"] diff --git a/tests/pointcloud/segmentation/__init__.py b/tests/pointcloud/segmentation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pointcloud/segmentation/test_data.py b/tests/pointcloud/segmentation/test_data.py new file mode 100644 index 0000000000..a4c808fff2 --- /dev/null +++ b/tests/pointcloud/segmentation/test_data.py @@ -0,0 +1,56 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from os.path import join + +import pytest +import torch +from pytorch_lightning import seed_everything + +from flash import Trainer +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.utils import download_data +from flash.pointcloud.segmentation import PointCloudSegmentation, PointCloudSegmentationData +from tests.helpers.utils import _POINTCLOUD_TESTING + + +@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +def test_pointcloud_segmentation_data(tmpdir): + + seed_everything(52) + + download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiMicro.zip", tmpdir) + + dm = PointCloudSegmentationData.from_folders(train_folder=join(tmpdir, "SemanticKittiMicro", "train")) + + class MockModel(PointCloudSegmentation): + def training_step(self, batch, batch_idx: int): + assert batch[DefaultDataKeys.INPUT]["xyz"][0].shape == torch.Size([2, 45056, 3]) + assert batch[DefaultDataKeys.INPUT]["xyz"][1].shape == torch.Size([2, 11264, 3]) + assert batch[DefaultDataKeys.INPUT]["xyz"][2].shape == torch.Size([2, 2816, 3]) + assert batch[DefaultDataKeys.INPUT]["xyz"][3].shape == torch.Size([2, 704, 3]) + assert batch[DefaultDataKeys.INPUT]["labels"].shape == torch.Size([2, 45056]) + assert batch[DefaultDataKeys.INPUT]["labels"].max() == 19 + assert batch[DefaultDataKeys.INPUT]["labels"].min() == 0 + assert batch[DefaultDataKeys.METADATA][0]["name"] in ("00_000000", "00_000001") + assert batch[DefaultDataKeys.METADATA][1]["name"] in ("00_000000", "00_000001") + + num_classes = 19 + model = MockModel(backbone="randlanet", num_classes=num_classes) + trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0) + trainer.fit(model, dm) + + predictions = model.predict(join(tmpdir, "SemanticKittiMicro", "predict")) + assert torch.stack(predictions[0][DefaultDataKeys.INPUT]).shape == torch.Size([45056, 3]) + assert torch.stack(predictions[0][DefaultDataKeys.PREDS]).shape == torch.Size([45056, 19]) + assert torch.stack(predictions[0][DefaultDataKeys.TARGET]).shape == torch.Size([45056]) diff --git a/tests/pointcloud/segmentation/test_datasets.py b/tests/pointcloud/segmentation/test_datasets.py new file mode 100644 index 0000000000..fa36606a26 --- /dev/null +++ b/tests/pointcloud/segmentation/test_datasets.py @@ -0,0 +1,37 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import patch + +import pytest + +from flash.pointcloud.segmentation.datasets import LyftDataset, SemanticKITTIDataset +from tests.helpers.utils import _POINTCLOUD_TESTING + + +@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +@patch("flash.pointcloud.segmentation.datasets.os.system") +def test_datasets(mock_system): + + LyftDataset("data") + assert mock_system.call_count == 2 + assert "lyft" in mock_system.call_args_list[0][0][0] + assert "data" in mock_system.call_args_list[0][0][0] + assert "lyft" in mock_system.call_args_list[1][0][0] + assert "data" in mock_system.call_args_list[1][0][0] + + mock_system.reset_mock() + SemanticKITTIDataset("data") + assert mock_system.call_count == 1 + assert "semantickitti" in mock_system.call_args_list[0][0][0] + assert "data" in mock_system.call_args_list[0][0][0] diff --git a/tests/pointcloud/segmentation/test_model.py b/tests/pointcloud/segmentation/test_model.py new file mode 100644 index 0000000000..234f867e64 --- /dev/null +++ b/tests/pointcloud/segmentation/test_model.py @@ -0,0 +1,41 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch + +from flash.pointcloud.segmentation import PointCloudSegmentation +from tests.helpers.utils import _POINTCLOUD_TESTING + + +@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +def test_backbones(): + + backbones = PointCloudSegmentation.available_backbones() + assert backbones == ["randlanet", "randlanet_s3dis", "randlanet_semantic_kitti", "randlanet_toronto3d"] + + +@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +@pytest.mark.parametrize( + "backbone", + [ + "randlanet", + "randlanet_s3dis", + "randlanet_toronto3d", + "randlanet_semantic_kitti", + ], +) +def test_models(backbone): + num_classes = 13 + model = PointCloudSegmentation(backbone=backbone, num_classes=num_classes) + assert model.head.weight.shape == torch.Size([13, 32]) diff --git a/tests/tabular/classification/test_data.py b/tests/tabular/classification/test_data.py index baa87b3451..b1e9ef3f25 100644 --- a/tests/tabular/classification/test_data.py +++ b/tests/tabular/classification/test_data.py @@ -23,7 +23,7 @@ if _PANDAS_AVAILABLE: import pandas as pd - from flash.tabular import TabularData + from flash.tabular import TabularClassificationData from flash.tabular.classification.utils import _categorize, _normalize TEST_DF_1 = pd.DataFrame( @@ -68,24 +68,24 @@ def test_normalize(): @pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required") -def test_emb_sizes(): +def test_embedding_sizes(): self = Mock() self.codes = {"category": [None, "a", "b", "c"]} self.cat_cols = ["category"] # use __get__ to test property with mocked self - es = TabularData.emb_sizes.__get__(self) # pylint: disable=E1101 + es = TabularClassificationData.embedding_sizes.__get__(self) # pylint: disable=E1101 assert es == [(4, 16)] self.codes = {} self.cat_cols = [] # use __get__ to test property with mocked self - es = TabularData.emb_sizes.__get__(self) # pylint: disable=E1101 + es = TabularClassificationData.embedding_sizes.__get__(self) # pylint: disable=E1101 assert es == [] self.codes = {"large": ["a"] * 100_000, "larger": ["b"] * 1_000_000} self.cat_cols = ["large", "larger"] # use __get__ to test property with mocked self - es = TabularData.emb_sizes.__get__(self) # pylint: disable=E1101 + es = TabularClassificationData.embedding_sizes.__get__(self) # pylint: disable=E1101 assert es == [(100_000, 17), (1_000_000, 31)] @@ -94,7 +94,7 @@ def test_tabular_data(tmpdir): train_data_frame = TEST_DF_1.copy() val_data_frame = TEST_DF_2.copy() test_data_frame = TEST_DF_2.copy() - dm = TabularData.from_data_frame( + dm = TabularClassificationData.from_data_frame( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", @@ -110,7 +110,7 @@ def test_tabular_data(tmpdir): target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) - assert target.shape == (1, ) + assert target.shape == (1,) @pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required") @@ -122,7 +122,7 @@ def test_categorical_target(tmpdir): # change int label to string df["label"] = df["label"].astype(str) - dm = TabularData.from_data_frame( + dm = TabularClassificationData.from_data_frame( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", @@ -138,7 +138,7 @@ def test_categorical_target(tmpdir): target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) - assert target.shape == (1, ) + assert target.shape == (1,) @pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required") @@ -146,7 +146,7 @@ def test_from_data_frame(tmpdir): train_data_frame = TEST_DF_1.copy() val_data_frame = TEST_DF_2.copy() test_data_frame = TEST_DF_2.copy() - dm = TabularData.from_data_frame( + dm = TabularClassificationData.from_data_frame( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", @@ -154,7 +154,7 @@ def test_from_data_frame(tmpdir): val_data_frame=val_data_frame, test_data_frame=test_data_frame, num_workers=0, - batch_size=1 + batch_size=1, ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: data = next(iter(dl)) @@ -162,7 +162,7 @@ def test_from_data_frame(tmpdir): target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) - assert target.shape == (1, ) + assert target.shape == (1,) @pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required") @@ -173,7 +173,7 @@ def test_from_csv(tmpdir): TEST_DF_2.to_csv(val_csv) TEST_DF_2.to_csv(test_csv) - dm = TabularData.from_csv( + dm = TabularClassificationData.from_csv( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", @@ -181,7 +181,7 @@ def test_from_csv(tmpdir): val_file=str(val_csv), test_file=str(test_csv), num_workers=0, - batch_size=1 + batch_size=1, ) for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: data = next(iter(dl)) @@ -189,14 +189,14 @@ def test_from_csv(tmpdir): target = data[DefaultDataKeys.TARGET] assert cat.shape == (1, 1) assert num.shape == (1, 2) - assert target.shape == (1, ) + assert target.shape == (1,) @pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required") def test_empty_inputs(): train_data_frame = TEST_DF_1.copy() with pytest.raises(RuntimeError): - TabularData.from_data_frame( + TabularClassificationData.from_data_frame( numerical_fields=None, categorical_fields=None, target_fields="label", diff --git a/tests/tabular/classification/test_data_model_integration.py b/tests/tabular/classification/test_data_model_integration.py index 349aeeaaba..3d4875f1dd 100644 --- a/tests/tabular/classification/test_data_model_integration.py +++ b/tests/tabular/classification/test_data_model_integration.py @@ -15,7 +15,7 @@ import pytorch_lightning as pl from flash.core.utilities.imports import _TABULAR_AVAILABLE -from flash.tabular import TabularClassifier, TabularData +from flash.tabular import TabularClassificationData, TabularClassifier from tests.helpers.utils import _TABULAR_TESTING if _TABULAR_AVAILABLE: @@ -37,7 +37,7 @@ def test_classification(tmpdir): train_data_frame = TEST_DF_1.copy() val_data_frame = TEST_DF_1.copy() test_data_frame = TEST_DF_1.copy() - data = TabularData.from_data_frame( + data = TabularClassificationData.from_data_frame( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", @@ -47,6 +47,6 @@ def test_classification(tmpdir): num_workers=0, batch_size=2, ) - model = TabularClassifier(num_features=3, num_classes=2, embedding_sizes=data.emb_sizes) + model = TabularClassifier(num_features=3, num_classes=2, embedding_sizes=data.embedding_sizes) trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir) trainer.fit(model, data) diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index d3cc3db332..e7ee5e9f5d 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -21,23 +21,21 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _TABULAR_AVAILABLE -from flash.tabular import TabularClassifier -from flash.tabular.classification.data import TabularData +from flash.tabular import TabularClassificationData, TabularClassifier from tests.helpers.utils import _SERVE_TESTING, _TABULAR_TESTING # ======== Mock functions ======== class DummyDataset(torch.utils.data.Dataset): - def __init__(self, num_num=16, num_cat=16): super().__init__() self.num_num = num_num self.num_cat = num_cat def __getitem__(self, index): - target = torch.randint(0, 10, size=(1, )).item() - cat_vars = torch.randint(0, 10, size=(self.num_cat, )) + target = torch.randint(0, 10, size=(1,)).item() + cat_vars = torch.randint(0, 10, size=(self.num_cat,)) num_vars = torch.rand(self.num_num) return {DefaultDataKeys.INPUT: (cat_vars, num_vars), DefaultDataKeys.TARGET: target} @@ -84,7 +82,7 @@ def test_jit(tmpdir): model.eval() # torch.jit.script doesn't work with tabnet - model = torch.jit.trace(model, ((torch.randint(0, 10, size=(1, 4)), torch.rand(1, 4)), )) + model = torch.jit.trace(model, ((torch.randint(0, 10, size=(1, 4)), torch.rand(1, 4)),)) # TODO: torch.jit.save doesn't work with tabnet # path = os.path.join(tmpdir, "test.pt") @@ -100,7 +98,7 @@ def test_jit(tmpdir): @mock.patch("flash._IS_TESTING", True) def test_serve(): train_data = {"num_col": [1.4, 2.5], "cat_col": ["positive", "negative"], "target": [1, 2]} - datamodule = TabularData.from_data_frame( + datamodule = TabularClassificationData.from_data_frame( "cat_col", "num_col", "target", diff --git a/tests/template/__init__.py b/tests/template/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/template/classification/test_data.py b/tests/template/classification/test_data.py index 6bdec2f2ef..b793849e08 100644 --- a/tests/template/classification/test_data.py +++ b/tests/template/classification/test_data.py @@ -49,7 +49,7 @@ def test_smoke(): def test_from_numpy(self): """Tests that ``TemplateData`` is properly created when using the ``from_numpy`` method.""" data = np.random.rand(10, self.num_features) - targets = np.random.randint(0, self.num_classes, (10, )) + targets = np.random.randint(0, self.num_classes, (10,)) # instantiate the data module dm = TemplateData.from_numpy( @@ -71,19 +71,19 @@ def test_from_numpy(self): data = next(iter(dm.train_dataloader())) rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert rows.shape == (2, self.num_features) - assert targets.shape == (2, ) + assert targets.shape == (2,) # check val data data = next(iter(dm.val_dataloader())) rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert rows.shape == (2, self.num_features) - assert targets.shape == (2, ) + assert targets.shape == (2,) # check test data data = next(iter(dm.test_dataloader())) rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert rows.shape == (2, self.num_features) - assert targets.shape == (2, ) + assert targets.shape == (2,) @staticmethod def test_from_sklearn(): @@ -107,16 +107,16 @@ def test_from_sklearn(): data = next(iter(dm.train_dataloader())) rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert rows.shape == (2, dm.num_features) - assert targets.shape == (2, ) + assert targets.shape == (2,) # check val data data = next(iter(dm.val_dataloader())) rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert rows.shape == (2, dm.num_features) - assert targets.shape == (2, ) + assert targets.shape == (2,) # check test data data = next(iter(dm.test_dataloader())) rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert rows.shape == (2, dm.num_features) - assert targets.shape == (2, ) + assert targets.shape == (2,) diff --git a/tests/template/classification/test_model.py b/tests/template/classification/test_model.py index 9fa57b80b9..cfd0f77f39 100644 --- a/tests/template/classification/test_model.py +++ b/tests/template/classification/test_model.py @@ -39,7 +39,7 @@ class DummyDataset(torch.utils.data.Dataset): def __getitem__(self, index): return { DefaultDataKeys.INPUT: torch.randn(self.num_features), - DefaultDataKeys.TARGET: torch.randint(self.num_classes - 1, (1, ))[0], + DefaultDataKeys.TARGET: torch.randint(self.num_classes - 1, (1,))[0], } def __len__(self) -> int: @@ -121,7 +121,7 @@ def test_predict_sklearn(): @pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed") -@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 16), ))]) +@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 16),))]) def test_jit(tmpdir, jitter, args): path = os.path.join(tmpdir, "test.pt") diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py index d5a3b680f9..4c42909b35 100644 --- a/tests/text/classification/test_data.py +++ b/tests/text/classification/test_data.py @@ -44,6 +44,12 @@ {"sentence": "this is a sentence three","lab":0} """ +TEST_JSON_DATA_FIELD = """{"data": [ +{"sentence": "this is a sentence one","lab":0}, +{"sentence": "this is a sentence two","lab":1}, +{"sentence": "this is a sentence three","lab":0}]} +""" + def csv_data(tmpdir): path = Path(tmpdir) / "data.csv" @@ -57,6 +63,12 @@ def json_data(tmpdir): return path +def json_data_with_field(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA_FIELD) + return path + + @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_csv(tmpdir): @@ -78,7 +90,7 @@ def test_test_valid(tmpdir): train_file=csv_path, val_file=csv_path, test_file=csv_path, - batch_size=1 + batch_size=1, ) batch = next(iter(dm.val_dataloader())) assert batch["labels"].item() in [0, 1] @@ -99,6 +111,18 @@ def test_from_json(tmpdir): assert "input_ids" in batch +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_json_with_field(tmpdir): + json_path = json_data_with_field(tmpdir) + dm = TextClassificationData.from_json( + "sentence", "lab", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data" + ) + batch = next(iter(dm.train_dataloader())) + assert batch["labels"].item() in [0, 1] + assert "input_ids" in batch + + @pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.") def test_text_module_not_found_error(): with pytest.raises(ModuleNotFoundError, match="[text]"): @@ -111,9 +135,7 @@ def test_text_module_not_found_error(): "cls, kwargs", [ (TextDataSource, {}), - (TextFileDataSource, { - "filetype": "csv" - }), + (TextFileDataSource, {"filetype": "csv"}), (TextCSVDataSource, {}), (TextJSONDataSource, {}), (TextSentencesDataSource, {}), diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py index 431b8f4cb8..7ca20d92c7 100644 --- a/tests/text/classification/test_model.py +++ b/tests/text/classification/test_model.py @@ -19,6 +19,7 @@ import torch from flash import Trainer +from flash.__main__ import main from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import TextClassifier from flash.text.classification.data import TextClassificationPostprocess, TextClassificationPreprocess @@ -28,11 +29,10 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index): return { - "input_ids": torch.randint(1000, size=(100, )), - "labels": torch.randint(2, size=(1, )).item(), + "input_ids": torch.randint(1000, size=(100,)), + "labels": torch.randint(2, size=(1,)).item(), } def __len__(self) -> int: @@ -87,3 +87,19 @@ def test_serve(): def test_load_from_checkpoint_dependency_error(): with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[text]'")): TextClassifier.load_from_checkpoint("not_a_real_checkpoint.pt") + + +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.parametrize( + "cli_args", + ( + ["flash", "text_classification", "--trainer.fast_dev_run", "True"], + ["flash", "text_classification", "--trainer.fast_dev_run", "True", "from_toxic"], + ), +) +def test_cli(cli_args): + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass diff --git a/tests/text/classification/test_ort.py b/tests/text/classification/test_ort.py new file mode 100644 index 0000000000..01d987e092 --- /dev/null +++ b/tests/text/classification/test_ort.py @@ -0,0 +1,62 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import pytest +import torch +from pytorch_lightning import Callback +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from flash import Trainer +from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE +from flash.text import TextClassifier +from flash.text.ort_callback import ORTCallback +from tests.helpers.boring_model import BoringModel +from tests.helpers.utils import _TEXT_TESTING +from tests.text.classification.test_model import DummyDataset, TEST_BACKBONE + +if _TORCH_ORT_AVAILABLE: + from torch_ort import ORTModule + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.") +def test_init_train_enable_ort(tmpdir): + class TestCallback(Callback): + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + assert isinstance(pl_module.model, ORTModule) + + model = TextClassifier(2, TEST_BACKBONE, enable_ort=True) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=TestCallback()) + trainer.fit( + model, + train_dataloader=torch.utils.data.DataLoader(DummyDataset()), + val_dataloaders=torch.utils.data.DataLoader(DummyDataset()), + ) + trainer.test(model, test_dataloaders=torch.utils.data.DataLoader(DummyDataset())) + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.") +def test_ort_callback_fails_no_model(tmpdir): + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback()) + with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"): + trainer.fit( + model, + train_dataloader=torch.utils.data.DataLoader(DummyDataset()), + val_dataloaders=torch.utils.data.DataLoader(DummyDataset()), + ) diff --git a/tests/text/seq2seq/__init__.py b/tests/text/seq2seq/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/text/seq2seq/core/test_data.py b/tests/text/seq2seq/core/test_data.py index 4f2144aa90..d52bd9132a 100644 --- a/tests/text/seq2seq/core/test_data.py +++ b/tests/text/seq2seq/core/test_data.py @@ -36,22 +36,11 @@ @pytest.mark.parametrize( "cls, kwargs", [ - (Seq2SeqDataSource, { - "backbone": "sshleifer/tiny-mbart" - }), - (Seq2SeqFileDataSource, { - "backbone": "sshleifer/tiny-mbart", - "filetype": "csv" - }), - (Seq2SeqCSVDataSource, { - "backbone": "sshleifer/tiny-mbart" - }), - (Seq2SeqJSONDataSource, { - "backbone": "sshleifer/tiny-mbart" - }), - (Seq2SeqSentencesDataSource, { - "backbone": "sshleifer/tiny-mbart" - }), + (Seq2SeqDataSource, {"backbone": "sshleifer/tiny-mbart"}), + (Seq2SeqFileDataSource, {"backbone": "sshleifer/tiny-mbart", "filetype": "csv"}), + (Seq2SeqCSVDataSource, {"backbone": "sshleifer/tiny-mbart"}), + (Seq2SeqJSONDataSource, {"backbone": "sshleifer/tiny-mbart"}), + (Seq2SeqSentencesDataSource, {"backbone": "sshleifer/tiny-mbart"}), (Seq2SeqPostprocess, {}), ], ) diff --git a/tests/text/seq2seq/summarization/test_metric.py b/tests/text/seq2seq/core/test_metrics.py similarity index 67% rename from tests/text/seq2seq/summarization/test_metric.py rename to tests/text/seq2seq/core/test_metrics.py index 9f17397b02..c16f828c37 100644 --- a/tests/text/seq2seq/summarization/test_metric.py +++ b/tests/text/seq2seq/core/test_metrics.py @@ -14,7 +14,7 @@ import pytest import torch -from flash.text.seq2seq.summarization.metric import RougeMetric +from flash.text.seq2seq.core.metrics import BLEUScore, RougeMetric from tests.helpers.utils import _TEXT_TESTING @@ -24,3 +24,11 @@ def test_rouge(): target = "Is your name John".split() metric = RougeMetric() assert torch.allclose(torch.tensor(metric(preds, target)["rouge1_recall"]).float(), torch.tensor(0.25), 1e-4) + + +@pytest.mark.parametrize("smooth, expected", [(False, 0.7598), (True, 0.8091)]) +def test_bleu_score(smooth, expected): + translate_corpus = ["the cat is on the mat".split()] + reference_corpus = [["there is a cat on the mat".split(), "a cat is on the mat".split()]] + metric = BLEUScore(smooth=smooth) + assert torch.allclose(metric(translate_corpus, reference_corpus), torch.tensor(expected), 1e-4) diff --git a/tests/text/seq2seq/question_answering/__init__.py b/tests/text/seq2seq/question_answering/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/text/seq2seq/question_answering/test_data.py b/tests/text/seq2seq/question_answering/test_data.py new file mode 100644 index 0000000000..8879282bba --- /dev/null +++ b/tests/text/seq2seq/question_answering/test_data.py @@ -0,0 +1,131 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from pathlib import Path + +import pytest + +from flash.text import QuestionAnsweringData +from tests.helpers.utils import _TEXT_TESTING + +TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing + +TEST_CSV_DATA = """input,target +this is a question one,this is an answer one +this is a question two,this is an answer two +this is a question three,this is an answer three +""" + +TEST_JSON_DATA = """ +{"input": "this is a question one","target":"this is an answer one"} +{"input": "this is a question two","target":"this is an answer two"} +{"input": "this is a question three","target":"this is an answer three"} +""" + +TEST_JSON_DATA_FIELD = """{"data": [ +{"input": "this is a question one","target":"this is an answer one"}, +{"input": "this is a question two","target":"this is an answer two"}, +{"input": "this is a question three","target":"this is an answer three"}]} +""" + + +def csv_data(tmpdir): + path = Path(tmpdir) / "data.csv" + path.write_text(TEST_CSV_DATA) + return path + + +def json_data(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA) + return path + + +def json_data_with_field(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA_FIELD) + return path + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_csv(tmpdir): + csv_path = csv_data(tmpdir) + dm = QuestionAnsweringData.from_csv("input", "target", backbone=TEST_BACKBONE, train_file=csv_path, batch_size=1) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_files(tmpdir): + csv_path = csv_data(tmpdir) + dm = QuestionAnsweringData.from_csv( + "input", + "target", + backbone=TEST_BACKBONE, + train_file=csv_path, + val_file=csv_path, + test_file=csv_path, + batch_size=1, + ) + batch = next(iter(dm.val_dataloader())) + assert "labels" in batch + assert "input_ids" in batch + + batch = next(iter(dm.test_dataloader())) + assert "labels" in batch + assert "input_ids" in batch + + +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_postprocess_tokenizer(tmpdir): + """Tests that the tokenizer property in ``SummarizationPostprocess`` resolves correctly when a different + backbone is used.""" + backbone = "sshleifer/bart-tiny-random" + csv_path = csv_data(tmpdir) + dm = QuestionAnsweringData.from_csv( + "input", + "target", + backbone=backbone, + train_file=csv_path, + batch_size=1, + ) + pipeline = dm.data_pipeline + pipeline.initialize() + assert pipeline._postprocess_pipeline.backbone == backbone + assert pipeline._postprocess_pipeline.tokenizer is not None + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_json(tmpdir): + json_path = json_data(tmpdir) + dm = QuestionAnsweringData.from_json("input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_json_with_field(tmpdir): + json_path = json_data_with_field(tmpdir) + dm = QuestionAnsweringData.from_json( + "input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data" + ) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch diff --git a/tests/text/seq2seq/question_answering/test_model.py b/tests/text/seq2seq/question_answering/test_model.py new file mode 100644 index 0000000000..ad4389b768 --- /dev/null +++ b/tests/text/seq2seq/question_answering/test_model.py @@ -0,0 +1,91 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import re +from unittest import mock + +import pytest +import torch + +from flash import Trainer +from flash.core.utilities.imports import _TEXT_AVAILABLE +from flash.text import QuestionAnsweringTask +from flash.text.seq2seq.core.data import Seq2SeqPostprocess +from flash.text.seq2seq.question_answering.data import QuestionAnsweringPreprocess +from tests.helpers.utils import _SERVE_TESTING, _TEXT_TESTING + +# ======== Mock functions ======== + + +class DummyDataset(torch.utils.data.Dataset): + def __getitem__(self, index): + return { + "input_ids": torch.randint(1000, size=(128,)), + "labels": torch.randint(1000, size=(128,)), + } + + def __len__(self) -> int: + return 100 + + +# ============================== + +TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_init_train(tmpdir): + model = QuestionAnsweringTask(TEST_BACKBONE) + train_dl = torch.utils.data.DataLoader(DummyDataset()) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, train_dl) + + +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_jit(tmpdir): + sample_input = { + "input_ids": torch.randint(1000, size=(1, 32)), + "attention_mask": torch.randint(1, size=(1, 32)), + } + path = os.path.join(tmpdir, "test.pt") + + model = QuestionAnsweringTask(TEST_BACKBONE) + model.eval() + + # Huggingface only supports `torch.jit.trace` + model = torch.jit.trace(model, [sample_input]) + + torch.jit.save(model, path) + model = torch.jit.load(path) + + out = model(sample_input) + assert isinstance(out, torch.Tensor) + + +@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@mock.patch("flash._IS_TESTING", True) +def test_serve(): + model = QuestionAnsweringTask(TEST_BACKBONE) + # TODO: Currently only servable once a preprocess and postprocess have been attached + model._preprocess = QuestionAnsweringPreprocess(backbone=TEST_BACKBONE) + model._postprocess = Seq2SeqPostprocess() + model.eval() + model.serve() + + +@pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.") +def test_load_from_checkpoint_dependency_error(): + with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[text]'")): + QuestionAnsweringTask.load_from_checkpoint("not_a_real_checkpoint.pt") diff --git a/tests/text/seq2seq/summarization/test_data.py b/tests/text/seq2seq/summarization/test_data.py index 2ab09f3636..ff359dcdf0 100644 --- a/tests/text/seq2seq/summarization/test_data.py +++ b/tests/text/seq2seq/summarization/test_data.py @@ -22,15 +22,21 @@ TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing TEST_CSV_DATA = """input,target -this is a sentence one,this is a translated sentence one -this is a sentence two,this is a translated sentence two -this is a sentence three,this is a translated sentence three +this is a sentence one,this is a summarized sentence one +this is a sentence two,this is a summarized sentence two +this is a sentence three,this is a summarized sentence three """ TEST_JSON_DATA = """ -{"input": "this is a sentence one","target":"this is a translated sentence one"} -{"input": "this is a sentence two","target":"this is a translated sentence two"} -{"input": "this is a sentence three","target":"this is a translated sentence three"} +{"input": "this is a sentence one","target":"this is a summarized sentence one"} +{"input": "this is a sentence two","target":"this is a summarized sentence two"} +{"input": "this is a sentence three","target":"this is a summarized sentence three"} +""" + +TEST_JSON_DATA_FIELD = """{"data": [ +{"input": "this is a sentence one","target":"this is a summarized sentence one"}, +{"input": "this is a sentence two","target":"this is a summarized sentence two"}, +{"input": "this is a sentence three","target":"this is a summarized sentence three"}]} """ @@ -46,6 +52,12 @@ def json_data(tmpdir): return path +def json_data_with_field(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA_FIELD) + return path + + @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_csv(tmpdir): @@ -80,9 +92,8 @@ def test_from_files(tmpdir): @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_postprocess_tokenizer(tmpdir): - """Tests that the tokenizer property in ``SummarizationPostprocess`` resolves correctly when a different backbone is - used. - """ + """Tests that the tokenizer property in ``SummarizationPostprocess`` resolves correctly when a different + backbone is used.""" backbone = "sshleifer/bart-tiny-random" csv_path = csv_data(tmpdir) dm = SummarizationData.from_csv( @@ -106,3 +117,15 @@ def test_from_json(tmpdir): batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_json_with_field(tmpdir): + json_path = json_data_with_field(tmpdir) + dm = SummarizationData.from_json( + "input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data" + ) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py index ccff5e6d85..c6adf69fdc 100644 --- a/tests/text/seq2seq/summarization/test_model.py +++ b/tests/text/seq2seq/summarization/test_model.py @@ -29,11 +29,10 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index): return { - "input_ids": torch.randint(1000, size=(128, )), - "labels": torch.randint(1000, size=(128, )), + "input_ids": torch.randint(1000, size=(128,)), + "labels": torch.randint(1000, size=(128,)), } def __len__(self) -> int: diff --git a/tests/text/seq2seq/translation/test_data.py b/tests/text/seq2seq/translation/test_data.py index 244cb27d4a..f87a51fdcd 100644 --- a/tests/text/seq2seq/translation/test_data.py +++ b/tests/text/seq2seq/translation/test_data.py @@ -33,6 +33,12 @@ {"input": "this is a sentence three","target":"this is a translated sentence three"} """ +TEST_JSON_DATA_FIELD = """{"data": [ +{"input": "this is a sentence one","target":"this is a translated sentence one"}, +{"input": "this is a sentence two","target":"this is a translated sentence two"}, +{"input": "this is a sentence three","target":"this is a translated sentence three"}]} +""" + def csv_data(tmpdir): path = Path(tmpdir) / "data.csv" @@ -46,6 +52,12 @@ def json_data(tmpdir): return path +def json_data_with_field(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA_FIELD) + return path + + @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_csv(tmpdir): @@ -67,7 +79,7 @@ def test_from_files(tmpdir): train_file=csv_path, val_file=csv_path, test_file=csv_path, - batch_size=1 + batch_size=1, ) batch = next(iter(dm.val_dataloader())) assert "labels" in batch @@ -86,3 +98,15 @@ def test_from_json(tmpdir): batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_json_with_field(tmpdir): + json_path = json_data_with_field(tmpdir) + dm = TranslationData.from_json( + "input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data" + ) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py index c49ccd4c24..237fa3bb5a 100644 --- a/tests/text/seq2seq/translation/test_model.py +++ b/tests/text/seq2seq/translation/test_model.py @@ -29,11 +29,10 @@ class DummyDataset(torch.utils.data.Dataset): - def __getitem__(self, index): return { - "input_ids": torch.randint(1000, size=(128, )), - "labels": torch.randint(1000, size=(128, )), + "input_ids": torch.randint(1000, size=(128,)), + "labels": torch.randint(1000, size=(128,)), } def __len__(self) -> int: diff --git a/tests/video/classification/test_model.py b/tests/video/classification/test_model.py index 27ad049411..d7d45aa69f 100644 --- a/tests/video/classification/test_model.py +++ b/tests/video/classification/test_model.py @@ -16,12 +16,14 @@ import re import tempfile from pathlib import Path +from unittest import mock import pytest import torch from torch.utils.data import SequentialSampler import flash +from flash.__main__ import main from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _VIDEO_AVAILABLE from flash.video import VideoClassificationData, VideoClassifier from tests.helpers.utils import _VIDEO_TESTING @@ -43,7 +45,7 @@ def create_dummy_video_frames(num_frames: int, height: int, width: int): for i in range(num_frames): xc = float(i) / num_frames yc = 1 - float(i) / (2 * num_frames) - d = torch.exp(-((x - xc)**2 + (y - yc)**2) / 2) * 255 + d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255 data.append(d.unsqueeze(2).repeat(1, 1, 3).byte()) return torch.stack(data, 0) @@ -51,9 +53,9 @@ def create_dummy_video_frames(num_frames: int, height: int, width: int): # https://github.com/facebookresearch/pytorchvideo/blob/4feccb607d7a16933d485495f91d067f177dd8db/tests/utils.py#L33 @contextlib.contextmanager def temp_encoded_video(num_frames: int, fps: int, height=10, width=10, prefix=None, directory=None): - """ - Creates a temporary lossless, mp4 video with synthetic content. Uses a context which - deletes the video after exit. + """Creates a temporary lossless, mp4 video with synthetic content. + + Uses a context which deletes the video after exit. """ # Lossless options. video_codec = "libx264rgb" @@ -101,8 +103,8 @@ def mock_encoded_video_dataset_file(): @contextlib.contextmanager def mock_encoded_video_dataset_folder(tmpdir): - """ - Creates a temporary mock encoded video directory tree with 2 videos labeled 1, 2. + """Creates a temporary mock encoded video directory tree with 2 videos labeled 1, 2. + Returns a directory that to this mock encoded video dataset and the video duration in seconds. """ num_frames = 10 @@ -150,28 +152,34 @@ def test_video_classifier_finetune(tmpdir): assert len(VideoClassifier.available_backbones()) > 5 train_transform = { - "post_tensor_transform": Compose([ - ApplyTransformToKey( - key="video", - transform=Compose([ - UniformTemporalSubsample(8), - RandomShortSideScale(min_size=256, max_size=320), - RandomCrop(244), - RandomHorizontalFlip(p=0.5), - ]), - ), - ]), - "per_batch_transform_on_device": Compose([ - ApplyTransformToKey( - key="video", - transform=K.VideoSequential( - K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), - K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0), - data_format="BCTHW", - same_on_frame=False - ) - ), - ]), + "post_tensor_transform": Compose( + [ + ApplyTransformToKey( + key="video", + transform=Compose( + [ + UniformTemporalSubsample(8), + RandomShortSideScale(min_size=256, max_size=320), + RandomCrop(244), + RandomHorizontalFlip(p=0.5), + ] + ), + ), + ] + ), + "per_batch_transform_on_device": Compose( + [ + ApplyTransformToKey( + key="video", + transform=K.VideoSequential( + K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), + K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0), + data_format="BCTHW", + same_on_frame=False, + ), + ), + ] + ), } datamodule = VideoClassificationData.from_folders( @@ -180,17 +188,17 @@ def test_video_classifier_finetune(tmpdir): clip_duration=half_duration, video_sampler=SequentialSampler, decode_audio=False, - train_transform=train_transform + train_transform=train_transform, ) - model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False) + model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False, backbone="slow_r50") - trainer = flash.Trainer(fast_dev_run=True) + trainer = flash.Trainer(fast_dev_run=True, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule) -@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.") def test_video_classifier_finetune_fiftyone(tmpdir): @@ -220,28 +228,34 @@ def test_video_classifier_finetune_fiftyone(tmpdir): assert len(VideoClassifier.available_backbones()) > 5 train_transform = { - "post_tensor_transform": Compose([ - ApplyTransformToKey( - key="video", - transform=Compose([ - UniformTemporalSubsample(8), - RandomShortSideScale(min_size=256, max_size=320), - RandomCrop(244), - RandomHorizontalFlip(p=0.5), - ]), - ), - ]), - "per_batch_transform_on_device": Compose([ - ApplyTransformToKey( - key="video", - transform=K.VideoSequential( - K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), - K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0), - data_format="BCTHW", - same_on_frame=False - ) - ), - ]), + "post_tensor_transform": Compose( + [ + ApplyTransformToKey( + key="video", + transform=Compose( + [ + UniformTemporalSubsample(8), + RandomShortSideScale(min_size=256, max_size=320), + RandomCrop(244), + RandomHorizontalFlip(p=0.5), + ] + ), + ), + ] + ), + "per_batch_transform_on_device": Compose( + [ + ApplyTransformToKey( + key="video", + transform=K.VideoSequential( + K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])), + K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0), + data_format="BCTHW", + same_on_frame=False, + ), + ), + ] + ), } datamodule = VideoClassificationData.from_fiftyone( @@ -250,12 +264,12 @@ def test_video_classifier_finetune_fiftyone(tmpdir): clip_duration=half_duration, video_sampler=SequentialSampler, decode_audio=False, - train_transform=train_transform + train_transform=train_transform, ) - model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False) + model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False, backbone="slow_r50") - trainer = flash.Trainer(fast_dev_run=True) + trainer = flash.Trainer(fast_dev_run=True, gpus=torch.cuda.device_count()) trainer.finetune(model, datamodule=datamodule) @@ -265,7 +279,7 @@ def test_jit(tmpdir): sample_input = torch.rand(1, 3, 32, 256, 256) path = os.path.join(tmpdir, "test.pt") - model = VideoClassifier(2, pretrained=False) + model = VideoClassifier(2, pretrained=False, backbone="slow_r50") model.eval() # pytorchvideo only works with `torch.jit.trace` @@ -283,3 +297,13 @@ def test_jit(tmpdir): def test_load_from_checkpoint_dependency_error(): with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[video]'")): VideoClassifier.load_from_checkpoint("not_a_real_checkpoint.pt") + + +@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +def test_cli(): + cli_args = ["flash", "video_classification", "--trainer.fast_dev_run", "True", "num_workers", "0"] + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass