diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index 7ea5aa01f5..942baaa555 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -5,7 +5,7 @@ updates:
# Enable version updates for python
- package-ecosystem: "pip"
# Look for a `requirements` in the `root` directory
- directory: "/"
+ directory: "/requirements"
# Check for updates once a week
schedule:
interval: "weekly"
diff --git a/.github/mergify.yml b/.github/mergify.yml
index 6b8badfda1..49352692cd 100644
--- a/.github/mergify.yml
+++ b/.github/mergify.yml
@@ -55,4 +55,4 @@ pull_request_rules:
actions:
request_reviews:
teams:
- - "@Lightning-AI/core-flash"
+ - "@Lightning-Universe/core-Flash"
diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml
index 86e10b29b4..6ee9182386 100644
--- a/.github/workflows/ci-checks.yml
+++ b/.github/workflows/ci-checks.yml
@@ -8,13 +8,13 @@ on:
jobs:
check-schema:
- uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.8.0
+ uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.9.0
with:
# todo: validation has some problem with `- ${{ each topic in parameters.domains }}:` construct
azure-dir: ""
check-package:
- uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.8.0
+ uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.9.0
with:
actions-ref: v0.8.0
artifact-name: dist-packages-${{ github.sha }}
diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml
index 6c9496c147..9dc5e1f5e1 100644
--- a/.github/workflows/ci-testing.yml
+++ b/.github/workflows/ci-testing.yml
@@ -4,8 +4,7 @@ name: CI testing
on: # Trigger the workflow on push or pull request, but only for the master branch
push:
branches: ["master", "release/*"]
- pull_request:
- branches: ["master", "release/*"]
+ pull_request: {}
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
@@ -16,8 +15,8 @@ defaults:
shell: bash
jobs:
- pytester:
+ pytester:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
@@ -42,13 +41,18 @@ jobs:
- { os: 'ubuntu-20.04', python-version: 3.9, topic: 'graph', extra: [] }
- { os: 'ubuntu-20.04', python-version: 3.9, topic: 'audio', extra: [] }
- { os: 'ubuntu-20.04', python-version: 3.8, topic: 'core', extra: [], requires: 'oldest' }
- - { os: 'ubuntu-20.04', python-version: 3.8, topic: 'serve', extra: [], requires: 'oldest' }
+ - { os: 'ubuntu-20.04', python-version: 3.8, topic: 'image', extra: [], requires: 'oldest' }
- { os: 'ubuntu-20.04', python-version: 3.8, topic: 'vision', extra: [], requires: 'oldest' }
+ - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'tabular', extra: [], requires: 'oldest' }
+ - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'text', extra: [], requires: 'oldest' }
+ #- { os: 'ubuntu-20.04', python-version: 3.8, topic: 'serve', extra: [], requires: 'oldest' } # todo
# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 50
env:
FREEZE_REQUIREMENTS: 1
+ TORCH_URL: https://download.pytorch.org/whl/cpu/torch_stable.html
+ TRANSFORMERS_CACHE: _hf_cache
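+        # TORCH_URL feeds pip's --find-links for CPU wheels; TRANSFORMERS_CACHE points the Hugging Face cache at a workspace path so the cache steps below can save and restore it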
steps:
- uses: actions/checkout@v3
@@ -67,7 +71,6 @@ jobs:
- name: Setup macOS
if: runner.os == 'macOS'
run: brew install libomp openblas lapack
-
- name: Setup Ubuntu
if: runner.os == 'Linux'
run: sudo apt-get install -y libsndfile1 graphviz
@@ -76,8 +79,7 @@ jobs:
if: matrix.requires == 'oldest'
run: |
import glob, os
- files = glob.glob(os.path.join("requirements", "*.txt")) + ['requirements.txt']
- files = ['requirements.txt']
+ files = glob.glob(os.path.join("requirements", "*.txt"))
for fname in files:
lines = [line.replace('>=', '==') for line in open(fname).readlines()]
open(fname, 'w').writelines(lines)
@@ -91,30 +93,57 @@ jobs:
gh_env.write(f"EXTRAS={','.join(extras)}")
shell: python
+ - name: Get pip cache dir
+ id: pip-cache
+ run: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
+ - name: Restore pip cache
+ uses: actions/cache/restore@v3
+ id: restore-cache
+ with:
+ path: ${{ steps.pip-cache.outputs.dir }}
+ key: pip-dependencies
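+          # the matching 'Save pip cache' step below runs only on master, so PR builds reuse this cache without overwriting it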
+
+ - name: Install package
+ run: |
+        # todo: some dependencies do not declare their extras in the correct format
+ python -m pip install "pip==22.3.1"
+        # todo: this is a hack so we can install packages that check the torch version at install time
+ pip install numpy Cython "torch>=1.11.0" -f $TORCH_URL
+ pip install .[$EXTRAS,test] --upgrade \
+ --prefer-binary \
+ -f $TORCH_URL \
+ -f https://data.pyg.org/whl/torch-1.13.1+cpu.html # this extra URL is for graph extras
+ pip list
+
+ - name: Restore HF cache
+ uses: actions/cache/restore@v3
+ with:
+ path: ${{ env.TRANSFORMERS_CACHE }}
+ key: cache-transformers
+
+ - name: DocTests
+ working-directory: src/
+ run: |
+ mv flash flashy
+ pytest . --doctest-modules --doctest-plus
+ mv flashy flash
+
- name: Install dependencies
- env:
- TORCH_URL: https://download.pytorch.org/whl/cpu/torch_stable.html
run: |
- python -m pip install "pip==22.2.1"
- pip install numpy Cython "torch>=1.7.1" -f $TORCH_URL
pip install .[$EXTRAS,test] \
-r requirements/testing_${{ matrix.topic }}.txt \
--upgrade --prefer-binary -f $TORCH_URL
+ pip cache info
pip list
- - name: Cache datasets
- uses: actions/cache@v3
+ - name: Save pip cache
+ if: github.ref == 'refs/heads/master'
+ uses: actions/cache/save@v3
with:
- path: data # This path is specific to Ubuntu
- key: flash-datasets-${{ hashFiles('tests/examples/test_scripts.py') }}
- restore-keys: flash-datasets-
-
- # ToDO
- #- name: DocTests
- # run: |
- # pytest src/ -vv # --reruns 3 --reruns-delay 2
+ path: ${{ steps.pip-cache.outputs.dir }}
+ key: pip-dependencies
- - name: Tests
+ - name: Testing
run: |
coverage run --source flash -m pytest \
tests/core \
@@ -124,6 +153,13 @@ jobs:
tests/${{ matrix.topic }} \
-v # --reruns 3 --reruns-delay 2
+ - name: Save HF cache
+ if: github.ref == 'refs/heads/master'
+ uses: actions/cache/save@v3
+ with:
+ path: ${{ env.TRANSFORMERS_CACHE }}
+ key: cache-transformers
+
- name: Statistics
run: |
coverage report
@@ -134,7 +170,7 @@ jobs:
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: ./coverage.xml
- flags: unittests
+ flags: unittests,${{ matrix.topic }},${{ matrix.extra }}
env_vars: OS,PYTHON
name: codecov-umbrella
fail_ci_if_error: false
diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml
index c4c4e26d07..c1b33215e6 100644
--- a/.github/workflows/docs-check.yml
+++ b/.github/workflows/docs-check.yml
@@ -2,17 +2,27 @@ name: "Check Docs"
# https://github.com/marketplace/actions/sphinx-build
on: # Trigger the workflow on push or pull request, but only for the master branch
- push: {}
- pull_request:
+ push:
branches: [master]
+ pull_request: {}
env:
FREEZE_REQUIREMENTS: 1
+ TORCH_URL: https://download.pytorch.org/whl/cpu/torch_stable.html
+ TRANSFORMERS_CACHE: _hf_cache
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
+ cancel-in-progress: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
+
+defaults:
+ run:
+ shell: bash
jobs:
+
make-docs:
runs-on: ubuntu-20.04
-
steps:
- uses: actions/checkout@v3
with:
@@ -21,28 +31,29 @@ jobs:
with:
python-version: 3.8
- # Note: This uses an internal pip API and may not always work
- # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
- name: Cache pip
uses: actions/cache@v3
with:
- path: ~/.cache/pip
- key: pip-${{ hashFiles('requirements.txt') }}
+ path: ~/.cache/pip # this is specific for Ubuntu
+ key: pip-${{ hashFiles('requirements/*.txt') }}
restore-keys: pip-
-
- name: Install dependencies
run: |
sudo apt-get update --fix-missing
# install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux
sudo apt-get install -y cmake pandoc texlive-latex-extra dvipng texlive-pictures
pip --version
- pip install . --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
- pip install --requirement requirements/docs.txt
+ pip install -e . -r requirements/docs.txt -f $TORCH_URL
pip list
- shell: bash
+
+ - name: Cache transformers
+ uses: actions/cache@v3
+ with:
+ path: ${{ env.TRANSFORMERS_CACHE }}
+ key: cache-transformers
- name: Make Documentation
- working-directory: docs
+ working-directory: docs/
run: make html --debug --jobs 2 SPHINXOPTS="-W --keep-going"
- name: Upload built docs
@@ -51,48 +62,42 @@ jobs:
name: docs-results-${{ github.sha }}
path: docs/build/html/
+
test-docs:
runs-on: ubuntu-20.04
-
steps:
- - uses: actions/checkout@v3
- with:
- submodules: true
- - uses: actions/setup-python@v4
- with:
- python-version: 3.8
-
- # Note: This uses an internal pip API and may not always work
- # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
- - name: Cache pip
- uses: actions/cache@v3
- with:
- path: ~/.cache/pip
- key: pip-${{ hashFiles('requirements/base.txt') }}
- restore-keys: pip-
-
- - name: Install dependencies
- run: |
- sudo apt-get update --fix-missing
- sudo apt-get install -y cmake pandoc libsndfile1
- pip --version
- pip install '.[all,test]' --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
- pip install --requirement requirements/docs.txt
- pip list
- shell: bash
-
- - name: Cache datasets
- uses: actions/cache@v3
- with:
- path: |
- docs/data
- data
- key: flash-datasets-docs
-
- - name: Test Documentation
- working-directory: docs
- env:
- SPHINX_MOCK_REQUIREMENTS: 0
- FIFTYONE_DO_NOT_TRACK: true
- FLASH_TESTING: 1
- run: make doctest
+ - uses: actions/checkout@v3
+ with:
+ submodules: true
+ - uses: actions/setup-python@v4
+ with:
+ python-version: 3.8
+
+ - name: Cache pip
+ uses: actions/cache@v3
+ with:
+ path: ~/.cache/pip # this is specific for Ubuntu
+ key: pip-${{ hashFiles('requirements/*.txt') }}
+ restore-keys: pip-
+
+ - name: Install dependencies
+ run: |
+ sudo apt-get update --fix-missing
+ sudo apt-get install -y cmake pandoc libsndfile1
+ pip --version
+ pip install -e '.[all,test]' -r requirements/docs.txt -f $TORCH_URL
+ pip list
+
+ - name: Cache transformers
+ uses: actions/cache@v3
+ with:
+ path: ${{ env.TRANSFORMERS_CACHE }}
+ key: cache-transformers
+
+ - name: Test Documentation
+ working-directory: docs/
+ env:
+ SPHINX_MOCK_REQUIREMENTS: 0
+ FIFTYONE_DO_NOT_TRACK: true
+ FLASH_TESTING: 1
+ run: make doctest
diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml
index 5736e12f29..3e14e5cdea 100644
--- a/.github/workflows/pypi-release.yml
+++ b/.github/workflows/pypi-release.yml
@@ -20,11 +20,17 @@ jobs:
python-version: 3.8
- name: Install dependencies
- run: pip install --user --upgrade setuptools wheel build
- - name: Build
- run: |
- python -m build
- ls -lh dist/
+ run: python -m pip install -U setuptools wheel build twine
+ - name: Build package
+ run: python -m build
+ - name: Check package
+ run: twine check dist/*
+
+ - name: Upload to release
+ uses: AButler/upload-release-assets@v2.0
+ with:
+ files: 'dist/*'
+ repo-token: ${{ secrets.GITHUB_TOKEN }}
# We do this, since failures on test.pypi aren't that bad
- name: Publish to Test PyPI
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 549f5e8e42..73a69f2b04 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -36,7 +36,7 @@ repos:
- id: detect-private-key
- repo: https://github.com/asottile/pyupgrade
- rev: v3.3.2
+ rev: v3.8.0
hooks:
- id: pyupgrade
args: [--py38-plus]
@@ -48,13 +48,11 @@ repos:
- id: nbstripout
- repo: https://github.com/PyCQA/docformatter
- rev: v1.6.5
+ rev: v1.7.3
hooks:
- id: docformatter
- args:
- - "--in-place"
- - "--wrap-summaries=120"
- - "--wrap-descriptions=120"
+ additional_dependencies: [tomli]
+ args: ["--in-place"]
- repo: https://github.com/psf/black
rev: 23.3.0
@@ -62,22 +60,8 @@ repos:
- id: black
name: Format code
- - repo: https://github.com/PyCQA/isort
- rev: 5.12.0
- hooks:
- - id: isort
- name: imports
-
- - repo: https://github.com/asottile/blacken-docs
- rev: 1.13.0
- hooks:
- - id: blacken-docs
- args:
- - "--line-length=120"
- - "--skip-errors"
-
- - repo: https://github.com/charliermarsh/ruff-pre-commit
- rev: v0.0.264
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.0.276
hooks:
- id: ruff
args: ["--fix"]
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 33faf8b8bd..e1ca71bf9b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,36 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
+## [UnReleased] - 2023-MM-DD
+
+### Added
+
+
+
+### Changed
+
+-
+
+
+### Fixed
+
+
+
+
+## [0.8.2] - 2023-06-30
+
+### Changed
+
+- Added GATE backbone for Tabular integrations ([#1559](https://github.com/Lightning-AI/lightning-flash/pull/1559))
+
+### Fixed
+
+- Fixed datamodule can't load files with square brackets in names ([#1501](https://github.com/Lightning-AI/lightning-flash/pull/1501))
+- Fixed channel dim selection on segmentation target ([#1509](https://github.com/Lightning-AI/lightning-flash/pull/1509))
+- Fixed use of `jsonargparse`, avoiding reliance on non-public internal logic ([#1620](https://github.com/Lightning-AI/lightning-flash/pull/1620))
+- Compatibility with `pytorch-tabular>=1.0` ([#1545](https://github.com/Lightning-AI/lightning-flash/pull/1545))
+- Compatibility with the latest `numpy` ([#1595](https://github.com/Lightning-AI/lightning-flash/pull/1595))
+
## [0.8.1] - 2022-11-08
### Added
@@ -15,6 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed compatibility with `lightning==1.8.0` ([#1479](https://github.com/Lightning-AI/lightning-flash/pull/1479))
- Fixed the error message to suggest installing `icevision`, if it's not found while loading data ([#1474](https://github.com/Lightning-AI/lightning-flash/pull/1474))
- Fixed compatibility with `torchmetrics==0.10.0` ([#1469](https://github.com/Lightning-AI/lightning-flash/pull/1469))
+- Fixed type of `n_gram` from bool to int in TranslationTask ([#1486](https://github.com/Lightning-AI/lightning-flash/pull/1486))
## [0.8.0] - 2022-09-02
diff --git a/README.md b/README.md
index da3188e95e..a83f7f6407 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@
Docs •
Contribute •
Community •
- Website •
+ Website •
License
@@ -23,8 +23,8 @@
[![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://www.pytorchlightning.ai/community)
[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/pytorch-lightning/blob/master/LICENSE)
-![CI testing](https://github.com/Lightning-AI/lightning-flash/workflows/CI%20testing/badge.svg?branch=master&event=push)
-[![codecov](https://codecov.io/gh/Lightning-AI/lightning-flash/branch/master/graph/badge.svg?token=oLuUr9q1vt)](https://codecov.io/gh/Lightning-AI/lightning-flash)
+[![CI testing](https://github.com/Lightning-Universe/lightning-flash/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-Universe/lightning-flash/actions/workflows/ci-testing.yml)
+[![codecov](https://codecov.io/gh/Lightning-Universe/lightning-flash/branch/master/graph/badge.svg?token=oLuUr9q1vt)](https://codecov.io/gh/Lightning-Universe/lightning-flash)
[![Documentation Status](https://readthedocs.org/projects/lightning-flash/badge/?version=latest)](https://lightning-flash.readthedocs.io/en/stable/?badge=stable)
[![DOI](https://zenodo.org/badge/333857397.svg)](https://zenodo.org/badge/latestdoi/333857397)
diff --git a/examples/audio/audio_classification.py b/examples/audio/audio_classification.py
index 2fc70a21aa..fb3882d5a8 100644
--- a/examples/audio/audio_classification.py
+++ b/examples/audio/audio_classification.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.audio import AudioClassificationData
from flash.core.data.utils import download_data
from flash.image import ImageClassifier
diff --git a/examples/audio/speech_recognition.py b/examples/audio/speech_recognition.py
index 3b3be34586..ed096854f0 100644
--- a/examples/audio/speech_recognition.py
+++ b/examples/audio/speech_recognition.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data
diff --git a/examples/graph/graph_classification.py b/examples/graph/graph_classification.py
index 18212ba546..973ec9c62c 100644
--- a/examples/graph/graph_classification.py
+++ b/examples/graph/graph_classification.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.utilities.imports import example_requires
from flash.graph import GraphClassificationData, GraphClassifier
diff --git a/examples/graph/graph_embedder.py b/examples/graph/graph_embedder.py
index edacfb6754..917d260855 100644
--- a/examples/graph/graph_embedder.py
+++ b/examples/graph/graph_embedder.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.utilities.imports import example_requires
from flash.graph import GraphClassificationData, GraphEmbedder
diff --git a/examples/image/baal_img_classification_active_learning.py b/examples/image/baal_img_classification_active_learning.py
index a3a97b7467..f8ce2e4711 100644
--- a/examples/image/baal_img_classification_active_learning.py
+++ b/examples/image/baal_img_classification_active_learning.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop
diff --git a/examples/image/face_detection.py b/examples/image/face_detection.py
index e4ed1406a5..7cbd47a128 100644
--- a/examples/image/face_detection.py
+++ b/examples/image/face_detection.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.utilities.imports import example_requires
from flash.image import FaceDetectionData, FaceDetector
diff --git a/examples/image/fiftyone_img_classification.py b/examples/image/fiftyone_img_classification.py
index 458f59220d..4522f705cf 100644
--- a/examples/image/fiftyone_img_classification.py
+++ b/examples/image/fiftyone_img_classification.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.core.integrations.fiftyone import visualize
from flash.image import ImageClassificationData, ImageClassifier
diff --git a/examples/image/fiftyone_img_classification_datasets.py b/examples/image/fiftyone_img_classification_datasets.py
index b4a4f474f4..987e5406c7 100644
--- a/examples/image/fiftyone_img_classification_datasets.py
+++ b/examples/image/fiftyone_img_classification_datasets.py
@@ -14,9 +14,8 @@
from itertools import chain
import fiftyone as fo
-import torch
-
import flash
+import torch
from flash.core.classification import FiftyOneLabelsOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
diff --git a/examples/image/fiftyone_img_embedding.py b/examples/image/fiftyone_img_embedding.py
index 9b7382034d..c1ec3cce45 100644
--- a/examples/image/fiftyone_img_embedding.py
+++ b/examples/image/fiftyone_img_embedding.py
@@ -13,10 +13,9 @@
# limitations under the License.
import fiftyone as fo
import fiftyone.brain as fob
+import flash
import numpy as np
import torch
-
-import flash
from flash.core.data.utils import download_data
from flash.image import ImageEmbedder
from flash.image.classification.data import ImageClassificationData
diff --git a/examples/image/image_classification.py b/examples/image/image_classification.py
index 82ee7cbba0..f4496bbebe 100644
--- a/examples/image/image_classification.py
+++ b/examples/image/image_classification.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
diff --git a/examples/image/image_classification_multi_label.py b/examples/image/image_classification_multi_label.py
index d3b7371c4a..b10db6f1b6 100644
--- a/examples/image/image_classification_multi_label.py
+++ b/examples/image/image_classification_multi_label.py
@@ -13,9 +13,8 @@
# limitations under the License.
import os
-import torch
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
diff --git a/examples/image/image_embedder.py b/examples/image/image_embedder.py
index c30c718a53..561002176b 100644
--- a/examples/image/image_embedder.py
+++ b/examples/image/image_embedder.py
@@ -11,12 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-from torchvision.datasets import CIFAR10
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageEmbedder
+from torchvision.datasets import CIFAR10
# 1. Download the data and prepare the datamodule
datamodule = ImageClassificationData.from_datasets(
diff --git a/examples/image/learn2learn_img_classification_imagenette.py b/examples/image/learn2learn_img_classification_imagenette.py
index b4d8603ef1..3dc80760c0 100644
--- a/examples/image/learn2learn_img_classification_imagenette.py
+++ b/examples/image/learn2learn_img_classification_imagenette.py
@@ -27,13 +27,12 @@
from dataclasses import dataclass
from typing import Tuple, Union
+import flash
import kornia.augmentation as Ka
import kornia.geometry as Kg
import learn2learn as l2l
import torch
import torchvision.transforms as T
-
-import flash
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys
diff --git a/examples/image/semantic_segmentation.py b/examples/image/semantic_segmentation.py
index 1b1a93cf93..cd511975a9 100644
--- a/examples/image/semantic_segmentation.py
+++ b/examples/image/semantic_segmentation.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData
diff --git a/examples/image/style_transfer.py b/examples/image/style_transfer.py
index 066503f1f2..05675996ee 100644
--- a/examples/image/style_transfer.py
+++ b/examples/image/style_transfer.py
@@ -13,9 +13,8 @@
# limitations under the License.
import os
-import torch
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.image.style_transfer import StyleTransfer, StyleTransferData
diff --git a/examples/pointcloud/pcloud_detection.py b/examples/pointcloud/pcloud_detection.py
index fd707813f0..054a5f24af 100644
--- a/examples/pointcloud/pcloud_detection.py
+++ b/examples/pointcloud/pcloud_detection.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.pointcloud import PointCloudObjectDetector, PointCloudObjectDetectorData
diff --git a/examples/pointcloud/pcloud_segmentation.py b/examples/pointcloud/pcloud_segmentation.py
index a4e88bb703..b1d04a1d63 100644
--- a/examples/pointcloud/pcloud_segmentation.py
+++ b/examples/pointcloud/pcloud_segmentation.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData
diff --git a/examples/pointcloud/visual_detection.py b/examples/pointcloud/visual_detection.py
index 9c3318960f..cff64e4c90 100644
--- a/examples/pointcloud/visual_detection.py
+++ b/examples/pointcloud/visual_detection.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.pointcloud.detection import PointCloudObjectDetector, PointCloudObjectDetectorData, launch_app
diff --git a/examples/pointcloud/visual_segmentation.py b/examples/pointcloud/visual_segmentation.py
index 8c4657f9f8..29e73f036a 100644
--- a/examples/pointcloud/visual_segmentation.py
+++ b/examples/pointcloud/visual_segmentation.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.pointcloud.segmentation import PointCloudSegmentation, PointCloudSegmentationData, launch_app
diff --git a/examples/serve/generic/boston_prediction/inference_server.py b/examples/serve/generic/boston_prediction/inference_server.py
index 995ec3917f..32bb853b91 100644
--- a/examples/serve/generic/boston_prediction/inference_server.py
+++ b/examples/serve/generic/boston_prediction/inference_server.py
@@ -13,7 +13,6 @@
# limitations under the License.
import hummingbird.ml
import sklearn.datasets
-
from flash.core.serve import Composition, ModelComponent, expose
from flash.core.serve.types import Number, Table
diff --git a/examples/serve/generic/detection/inference.py b/examples/serve/generic/detection/inference.py
index 2ae25affa1..8674ab8301 100644
--- a/examples/serve/generic/detection/inference.py
+++ b/examples/serve/generic/detection/inference.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torchvision
-
from flash.core.serve import Composition, ModelComponent, expose
from flash.core.serve.types import BBox, Image, Label, Repeated
diff --git a/examples/serve/image_classification/client.py b/examples/serve/image_classification/client.py
index 77d9e89e7b..9fed1f1c30 100644
--- a/examples/serve/image_classification/client.py
+++ b/examples/serve/image_classification/client.py
@@ -14,9 +14,8 @@
import base64
from pathlib import Path
-import requests
-
import flash
+import requests
with (Path(flash.ASSETS_ROOT) / "fish.jpg").open("rb") as f:
imgstr = base64.b64encode(f.read()).decode("UTF-8")
diff --git a/examples/serve/object_detection/client.py b/examples/serve/object_detection/client.py
index 77d9e89e7b..9fed1f1c30 100644
--- a/examples/serve/object_detection/client.py
+++ b/examples/serve/object_detection/client.py
@@ -14,9 +14,8 @@
import base64
from pathlib import Path
-import requests
-
import flash
+import requests
with (Path(flash.ASSETS_ROOT) / "fish.jpg").open("rb") as f:
imgstr = base64.b64encode(f.read()).decode("UTF-8")
diff --git a/examples/serve/semantic_segmentation/client.py b/examples/serve/semantic_segmentation/client.py
index bad730e6a5..35b69d28a7 100644
--- a/examples/serve/semantic_segmentation/client.py
+++ b/examples/serve/semantic_segmentation/client.py
@@ -14,9 +14,8 @@
import base64
from pathlib import Path
-import requests
-
import flash
+import requests
with (Path(flash.ASSETS_ROOT) / "road.png").open("rb") as f:
imgstr = base64.b64encode(f.read()).decode("UTF-8")
diff --git a/examples/serve/speech_recognition/client.py b/examples/serve/speech_recognition/client.py
index c855a37204..f610fd41a6 100644
--- a/examples/serve/speech_recognition/client.py
+++ b/examples/serve/speech_recognition/client.py
@@ -14,9 +14,8 @@
import base64
from pathlib import Path
-import requests
-
import flash
+import requests
with (Path(flash.ASSETS_ROOT) / "example.wav").open("rb") as f:
audio_str = base64.b64encode(f.read()).decode("UTF-8")
diff --git a/examples/serve/tabular_classification/client.py b/examples/serve/tabular_classification/client.py
index 4e6506b554..ce38beec59 100644
--- a/examples/serve/tabular_classification/client.py
+++ b/examples/serve/tabular_classification/client.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pandas as pd
import requests
-
from flash.core.data.utils import download_data
# 1. Download the data
diff --git a/examples/tabular/forecasting_interpretable.py b/examples/tabular/forecasting_interpretable.py
index 10ecc6b305..08f33728af 100644
--- a/examples/tabular/forecasting_interpretable.py
+++ b/examples/tabular/forecasting_interpretable.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.integrations.pytorch_forecasting import convert_predictions
from flash.core.utilities.imports import example_requires
from flash.tabular.forecasting import TabularForecaster, TabularForecastingData
diff --git a/examples/tabular/tabular_classification.py b/examples/tabular/tabular_classification.py
index 244eb88455..7ef36bd6c6 100644
--- a/examples/tabular/tabular_classification.py
+++ b/examples/tabular/tabular_classification.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.tabular import TabularClassificationData, TabularClassifier
# 1. Create the DataModule
diff --git a/examples/tabular/tabular_forecasting.py b/examples/tabular/tabular_forecasting.py
index fe21b7a469..ba9aca94bb 100644
--- a/examples/tabular/tabular_forecasting.py
+++ b/examples/tabular/tabular_forecasting.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.utilities.imports import example_requires
from flash.tabular.forecasting import TabularForecaster, TabularForecastingData
diff --git a/examples/tabular/tabular_regression.py b/examples/tabular/tabular_regression.py
index a6c77b551b..d6b0a72186 100644
--- a/examples/tabular/tabular_regression.py
+++ b/examples/tabular/tabular_regression.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.tabular import TabularRegressionData, TabularRegressor
# 1. Create the DataModule
diff --git a/examples/template.py b/examples/template.py
index 2af11ecf60..be8964982b 100644
--- a/examples/template.py
+++ b/examples/template.py
@@ -11,12 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import flash
import numpy as np
import torch
-from sklearn import datasets
-
-import flash
from flash.template import TemplateData, TemplateSKLearnClassifier
+from sklearn import datasets
# 1. Create the DataModule
datamodule = TemplateData.from_sklearn(
diff --git a/examples/text/text_classification.py b/examples/text/text_classification.py
index 303be0e3b0..912dd8312f 100644
--- a/examples/text/text_classification.py
+++ b/examples/text/text_classification.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier
diff --git a/examples/text/text_classification_multi_label.py b/examples/text/text_classification_multi_label.py
index d5dce1e4f9..b82d5ea633 100644
--- a/examples/text/text_classification_multi_label.py
+++ b/examples/text/text_classification_multi_label.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier
diff --git a/examples/text/text_embedder.py b/examples/text/text_embedder.py
index c5d6a010aa..e1897394fb 100644
--- a/examples/text/text_embedder.py
+++ b/examples/text/text_embedder.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.text import TextClassificationData, TextEmbedder
# 1. Create the DataModule
diff --git a/examples/text/translation.py b/examples/text/translation.py
index 0bee29eca5..581f2a15c7 100644
--- a/examples/text/translation.py
+++ b/examples/text/translation.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.text import TranslationData, TranslationTask
diff --git a/examples/video/video_classification.py b/examples/video/video_classification.py
index fc66c8ab00..e2991f9770 100644
--- a/examples/video/video_classification.py
+++ b/examples/video/video_classification.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
-
import flash
+import torch
from flash.core.data.utils import download_data
from flash.video import VideoClassificationData, VideoClassifier
diff --git a/pyproject.toml b/pyproject.toml
index 94d4a563d1..254cc502b9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,14 +27,13 @@ norecursedirs = [
]
addopts = [
"--strict-markers",
- "--doctest-modules",
"--color=yes",
"--disable-pytest-warnings",
]
#filterwarnings = [
# "error::FutureWarning",
#]
-xfail_strict = false # todo
+xfail_strict = true
junit_duration_report = "call"
[tool.coverage.report]
@@ -49,15 +48,11 @@ exclude_lines = [
line-length = 120
exclude = "(.eggs|.git|.hg|.mypy_cache|.venv|_build|buck-out|build|dist)"
-[tool.isort]
-known_first_party = [
- "flash",
- "examples",
- "tests",
-]
-skip_glob = []
-profile = "black"
-line_length = 120
+[tool.docformatter]
+recursive = true
+wrap-summaries = 120
+wrap-descriptions = 120
+blank = true
[tool.ruff]
@@ -66,6 +61,7 @@ line-length = 120
select = [
"E", "W", # see: https://pypi.org/project/pycodestyle
"F", # see: https://pypi.org/project/pyflakes
+ "I", # see: isort
# "D", # see: https://pypi.org/project/pydocstyle
# "N", # see: https://pypi.org/project/pep8-naming
]
diff --git a/requirements.txt b/requirements.txt
index d699c72e91..38a362e81a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,15 +1 @@
-# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup
-
-packaging <23.0
-setuptools <=59.5.0 # Prevent install bug with tensorboard
-numpy <1.24 # strict - freeze for using np.long
-torch >1.7.0
-torchmetrics >0.7.0, <0.11.0 # strict
-pytorch-lightning >1.6.0, <1.9.0 # strict
-pyDeprecate >0.1.0
-pandas >1.1.0, <=1.5.2
-jsonargparse[signatures] >4.0.0, <=4.9.0
-click >=7.1.2, <=8.1.3
-protobuf <=3.20.1
-fsspec[http] >=2022.5.0,<=2022.7.1
-lightning-utilities >=0.4.1
+-r ./requirements/base.txt
diff --git a/requirements/base.txt b/requirements/base.txt
new file mode 100644
index 0000000000..bb953b602f
--- /dev/null
+++ b/requirements/base.txt
@@ -0,0 +1,15 @@
+# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup
+
+packaging <24.0
+setuptools <=68.0.0 # Prevent install bug with tensorboard
+numpy <1.26
+torch >1.7.0, <=2.0.1
+torchmetrics >0.7.0, <0.11.0 # strict
+pytorch-lightning >1.8.0, <2.0.0 # strict
+pyDeprecate >0.2.0
+pandas >1.1.0, <=2.0.3
+jsonargparse[signatures] >=4.22.0, <4.23.0
+click >=7.1.2, <=8.1.6
+protobuf <=3.20.1
+fsspec[http] >=2022.5.0,<=2023.6.0
+lightning-utilities >=0.4.1
diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt
index c28ac77842..d2d51f0cce 100644
--- a/requirements/datatype_audio.txt
+++ b/requirements/datatype_audio.txt
@@ -1,7 +1,8 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup
-torchaudio <=0.13.1
-torchvision <=0.14.1
-librosa >=0.8.1, <=0.9.2
-transformers >=4.13.0, <=4.25.1
-datasets >=1.16.1, <=2.8.0
+numpy <1.26
+torchaudio <=2.0.2
+torchvision <=0.15.2
+librosa >=0.8.1, <=0.10.0.post2
+transformers >=4.13.0, <=4.30.2
+datasets >1.16.1, <=2.13.1
diff --git a/requirements/datatype_graph.txt b/requirements/datatype_graph.txt
index fe223f68c8..72c2abdae0 100644
--- a/requirements/datatype_graph.txt
+++ b/requirements/datatype_graph.txt
@@ -1,8 +1,12 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup
-torch-scatter <=2.1.0
-torch-sparse <=0.6.16
-torch-geometric >=2.0.0, <=2.2.0
-torch-cluster <=1.6.0
-networkx <=2.8.8
-class-resolver >=0.3.2, <=0.3.10
+torch-scatter <=2.1.1
+torch-sparse <=0.6.17
+torch-geometric >=2.0.0, <=2.3.1
+torch-cluster <=1.6.1
+networkx <=3.1
+class-resolver >=0.3.2, <=0.4.2
+
+# todo: check if we can bump these versions; if bumped, the CI find-links URL also needs updating
+torch ==1.13.1
+torchvision ==0.14.1
diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt
index faf62371df..3d19f0cbf3 100644
--- a/requirements/datatype_image.txt
+++ b/requirements/datatype_image.txt
@@ -1,14 +1,14 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup
-torchvision <=0.14.1
-timm >0.4.5, <=0.6.11 # effdet 0.3.0 depends on timm>=0.4.12
-lightning-bolts >0.3.3, <=0.6.0
-Pillow >7.1, <=9.3.0
-albumentations <=1.3.0
+torchvision <=0.15.2
+timm >0.4.5, <=0.9.2 # effdet 0.3.0 depends on timm>=0.4.12
+lightning-bolts >=0.7.0, <0.8.0
+Pillow >8.0, <=10.0.0
+albumentations >1.0.0, <=1.3.1
pystiche >1.0.0, <=1.0.1
-ftfy <=6.1.1
-regex <=2022.10.31
+ftfy >6.0.0, <=6.1.1
+regex <=2023.6.3
sahi >=0.8.19, <0.11 # strict - Fixes compatibility with icevision
-icevision >0.8
-icedata <=0.5.1 # dead
+icevision >0.8, <0.13.0
+icedata >0.5.0, <=0.5.1 # dead
diff --git a/requirements/datatype_image_baal.txt b/requirements/datatype_image_baal.txt
index 05f17d6913..a02882e9af 100644
--- a/requirements/datatype_image_baal.txt
+++ b/requirements/datatype_image_baal.txt
@@ -1,4 +1,4 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup
# This is a separate file, as baal integration is affected by vissl installation (conflicts)
-baal >=1.3.2, <=1.7.0
+baal >=1.3.2, <=1.8.0
diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt
index 11a8f19a8e..a31a6bdf69 100644
--- a/requirements/datatype_image_extras.txt
+++ b/requirements/datatype_image_extras.txt
@@ -1,17 +1,14 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup
-matplotlib <=3.6.2
-fiftyone <0.19.0
-classy-vision <=0.6
-effdet <=0.3.0
-kornia >0.5.1, <=0.6.9
+matplotlib >3.0.0, <=3.7.2
+fiftyone <0.22.0
+classy-vision <=0.7.0
+effdet <=0.4.1
+kornia >0.5.1, <=0.6.12
learn2learn <=0.1.7; platform_system != "Windows" # dead
fastface <=0.1.3 # dead
fairscale
-# pin PL for testing, remove when fastface is updated
-pytorch-lightning <1.5.0
-
# pinned PL so we force a compatible TM version
torchmetrics<0.8.0
diff --git a/requirements/datatype_image_segm.txt b/requirements/datatype_image_segm.txt
index cf37ef2c0d..691d706a09 100644
--- a/requirements/datatype_image_segm.txt
+++ b/requirements/datatype_image_segm.txt
@@ -1,4 +1,4 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup
# This is a separate file, as segmentation integration is affected by vissl installation (conflicts)
-segmentation-models-pytorch >0.2.0, <=0.3.1
+segmentation-models-pytorch >0.2.0, <=0.3.3
diff --git a/requirements/datatype_image_vissl.txt b/requirements/datatype_image_vissl.txt
index 1d196351d9..475f30fbfb 100644
--- a/requirements/datatype_image_vissl.txt
+++ b/requirements/datatype_image_vissl.txt
@@ -2,3 +2,11 @@
# This is a separate file, as vissl integration is affected by baal installation (conflicts)
vissl >=0.1.5, <=0.1.6 # dead
+
+# CI: lover bound is set just to limit the search space for pip/installation
+torch >1.10.0
+torchvision >0.11.0
+torchmetrics >0.10.0
+timm >0.9.0
+sahi >0.10.0
+icevision >0.11
diff --git a/requirements/datatype_pointcloud.txt b/requirements/datatype_pointcloud.txt
index c3cc0d490f..8b3b4cfa16 100644
--- a/requirements/datatype_pointcloud.txt
+++ b/requirements/datatype_pointcloud.txt
@@ -3,4 +3,4 @@
open3d >=0.17.0, <0.18.0
# torch >=1.8.0, <1.9.0
# torchvision >0.9.0, <0.10.0
-tensorboard <=2.11.0
+tensorboard <=2.13.0
diff --git a/requirements/datatype_tabular.txt b/requirements/datatype_tabular.txt
index fc9eeb7775..44baa2e54a 100644
--- a/requirements/datatype_tabular.txt
+++ b/requirements/datatype_tabular.txt
@@ -1,8 +1,7 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup
-scikit-learn <=1.2.0
+scikit-learn <=1.3.0
pytorch-forecasting >=0.10.0, <=0.10.3
-# pytorch-tabular >=1.0.2, <1.0.3 # pending requirements resolving
-pytorch-tabular @ https://github.com/manujosephv/pytorch_tabular/archive/refs/heads/main.zip
+pytorch-tabular >=1.0.2, <1.0.3
torchmetrics >=0.10.0
omegaconf <=2.1.1, <=2.1.1
diff --git a/requirements/datatype_text.txt b/requirements/datatype_text.txt
index 786aa1f9d9..bcb51e7296 100644
--- a/requirements/datatype_text.txt
+++ b/requirements/datatype_text.txt
@@ -1,11 +1,11 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup
-torchvision <=0.14.1
-sentencepiece >=0.1.95, <=0.1.97
-filelock <=3.8.2
-transformers >4.13.0, <=4.25.1
-torchmetrics[text] >0.5.0, <0.11.0
-datasets >2.0.0, <=2.8.0
+torchvision <=0.15.2
+sentencepiece >=0.1.95, <=0.1.99
+filelock <=3.12.2
+transformers >=4.13.0, <=4.30.2
+torchmetrics[text] >0.5.0, <1.1.0
+datasets >=2.0.0, <=2.13.1
sentence-transformers <=2.2.2
ftfy <=6.1.1
-regex <=2022.10.31
+regex <=2023.6.3
diff --git a/requirements/datatype_video.txt b/requirements/datatype_video.txt
index 49c6ca8fdd..9af3ec213d 100644
--- a/requirements/datatype_video.txt
+++ b/requirements/datatype_video.txt
@@ -1,8 +1,8 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup
-torchvision <=0.14.1
-Pillow >7.1, <=9.3.0
-kornia >=0.5.1, <=0.6.9
+torchvision <=0.15.2
+Pillow >7.1, <=10.0.0
+kornia >=0.5.1, <=0.6.12
pytorchvideo ==0.1.2
-fiftyone <=0.18.0
+fiftyone <=0.21.4
diff --git a/requirements/serve.txt b/requirements/serve.txt
index e581827880..1f49801318 100644
--- a/requirements/serve.txt
+++ b/requirements/serve.txt
@@ -1,14 +1,14 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup
-pillow >7.1, <=9.3.0
-pyyaml <=6.0
-cytoolz <=0.12.1
-graphviz <=0.20.1
-tqdm <=4.64.1
-fastapi >=0.65.2, <=0.68.2
-pydantic >1.8.1, <=1.10.2
-starlette ==0.14.2
-uvicorn[standard] >=0.12.0, <=0.20.0
-aiofiles <=22.1.0
-jinja2 >=3.0.0, <3.1.0
-torchvision <=0.14.1
+pillow >9.0.0, <=10.0.0
+pyyaml >5.4, <=6.0.1
+cytoolz >0.11, <=0.12.2
+graphviz >=0.19, <=0.20.1
+tqdm >4.60, <=4.65.0
+fastapi >0.65, <=0.100.0
+pydantic >1.8.1, <=2.0.3
+starlette <=0.30.0
+uvicorn[standard] >=0.12.0, <=0.23.2
+aiofiles >22.1.0, <=23.1.0
+jinja2 >=3.0.0, <3.2.0
+torchvision >0.10.0, <=0.15.2
diff --git a/requirements/test.txt b/requirements/test.txt
index d5bca03782..1621bd6233 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -1,11 +1,8 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup
coverage[toml]
-pytest >6.2, <7.0
-pytest-doctestplus >0.12.0
-pytest-rerunfailures >11.0.0
+pytest ==7.4.0
+pytest-doctestplus ==0.13.0
+pytest-rerunfailures ==12.0
pytest-forked ==1.6.0
-pytest-mock ==3.10.0
-
-scikit-learn
-torch_optimizer
+pytest-mock ==3.11.1
diff --git a/requirements/testing_audio.txt b/requirements/testing_audio.txt
index 04365b9b5f..7a0fcb577b 100644
--- a/requirements/testing_audio.txt
+++ b/requirements/testing_audio.txt
@@ -1,10 +1,10 @@
-matplotlib
-torch ==1.11.0
-torchaudio ==0.11.0
-torchvision ==0.12.0
+matplotlib >3.0.0, <=3.7.2
+torch ==2.0.1
+torchaudio ==2.0.2
+torchvision ==0.15.2
-timm >0.4.5, <=0.6.11 # effdet 0.3.0 depends on timm>=0.4.12
-lightning-bolts >=0.3.3, <=0.6.0
-Pillow >7.1, <=9.3.0
-albumentations <=1.3.0
+timm >0.4.5, <=0.9.2 # effdet 0.3.0 depends on timm>=0.4.12
+lightning-bolts >=0.7.0, <0.8.0
+Pillow >8.0, <=10.0.0
+albumentations >1.0.0, <=1.3.1
pystiche >1.0.0, <=1.0.1
diff --git a/requirements/testing_graph.txt b/requirements/testing_graph.txt
index 59cd09b99d..be9d455a80 100644
--- a/requirements/testing_graph.txt
+++ b/requirements/testing_graph.txt
@@ -1,5 +1,6 @@
-torch ==1.11.0
-torchvision ==0.12.0
+# todo: try to bump these versions
+torch ==1.13.1
+torchvision ==0.14.1
-f https://download.pytorch.org/whl/cpu/torch_stable.html
--f https://data.pyg.org/whl/torch-1.11.0+cpu.html
+-f https://data.pyg.org/whl/torch-1.13.1+cpu.html
diff --git a/requirements/testing_serve.txt b/requirements/testing_serve.txt
index 95d1c9779a..99e4817627 100644
--- a/requirements/testing_serve.txt
+++ b/requirements/testing_serve.txt
@@ -1,7 +1,13 @@
-sahi ==0.8.19
-
+# source all main domains
-r datatype_image.txt
-r datatype_video.txt
-r datatype_tabular.txt
-r datatype_text.txt
-r datatype_audio.txt
+
+# CI: limit the search space for pip/installation
+sahi ==0.8.19
+torch ==1.10.2
+torchaudio ==0.10.2
+torchvision ==0.11.3
+torchmetrics ==0.10.3
diff --git a/setup.py b/setup.py
index 56074f92e5..d72d93b482 100644
--- a/setup.py
+++ b/setup.py
@@ -33,6 +33,7 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str:
>>> _load_readme_description(_PATH_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
'...'
+
"""
path_readme = os.path.join(path_dir, "README.md")
text = open(path_readme, encoding="utf-8").read()
@@ -49,9 +50,7 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str:
# codecov badge
text = text.replace("/branch/master/graph/badge.svg", f"/release/{ver}/graph/badge.svg")
# replace github badges for release ones
- text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={ver}")
-
- return text
+ return text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={ver}")
def _augment_requirement(ln: str, comment_char: str = "#", unfreeze: bool = True) -> str:
@@ -67,6 +66,7 @@ def _augment_requirement(ln: str, comment_char: str = "#", unfreeze: bool = True
'arrow>=1.2.0, <=1.2.2 # strict'
>>> _augment_requirement("arrow", unfreeze=True)
'arrow'
+
"""
# filer all comments
if comment_char in ln:
@@ -97,6 +97,7 @@ def _load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: boo
>>> path_req = os.path.join(_PATH_ROOT, "requirements")
>>> _load_requirements(path_req, "docs.txt") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['sphinx>=4.0', ...]
+
"""
with open(os.path.join(path_dir, file_name)) as file:
lines = [ln.strip() for ln in file.readlines()]
@@ -142,8 +143,7 @@ def _get_extras(path_dir: str = _PATH_REQUIRE) -> dict:
extras_req["all"] = _expand_reqs(extras_req, ["vision", "tabular", "text", "audio"])
extras_req["dev"] = _expand_reqs(extras_req, ["all", "test", "docs"])
# filter the uniques
- extras_req = {n: list(set(req)) for n, req in extras_req.items()}
- return extras_req
+ return {n: list(set(req)) for n, req in extras_req.items()}
# https://packaging.python.org/discussions/install-requires-vs-requirements /
@@ -151,52 +151,53 @@ def _get_extras(path_dir: str = _PATH_REQUIRE) -> dict:
# what happens and to non-engineers they won't know to look in init ...
# the goal of the project is simplicity for researchers, don't want to add too much
# engineer specific practices
-setup(
- name="lightning-flash",
- version=about.__version__,
- description=about.__docs__,
- author=about.__author__,
- author_email=about.__author_email__,
- url=about.__homepage__,
- download_url="https://github.com/Lightning-AI/lightning-flash",
- license=about.__license__,
- package_dir={"": "src"},
- packages=find_packages(where="src"),
- long_description=_load_readme_description(_PATH_ROOT, homepage=about.__homepage__, ver=about.__version__),
- long_description_content_type="text/markdown",
- include_package_data=True,
- entry_points={
- "console_scripts": ["flash=flash.__main__:main"],
- },
- zip_safe=False,
- keywords=["deep learning", "pytorch", "AI"],
- python_requires=">=3.8",
- install_requires=_load_requirements(path_dir=_PATH_ROOT, file_name="requirements.txt"),
- extras_require=_get_extras(),
- project_urls={
- "Bug Tracker": "https://github.com/Lightning-AI/lightning-flash/issues",
- "Documentation": "https://lightning-flash.rtfd.io/en/latest/",
- "Source Code": "https://github.com/Lightning-AI/lightning-flash",
- },
- classifiers=[
- "Environment :: Console",
- "Natural Language :: English",
- # How mature is this project? Common values are
- # 3 - Alpha, 4 - Beta, 5 - Production/Stable
- "Development Status :: 4 - Beta",
- # Indicate who your project is intended for
- "Intended Audience :: Developers",
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
- "Topic :: Scientific/Engineering :: Image Recognition",
- "Topic :: Scientific/Engineering :: Information Analysis",
- # Pick your license as you wish
- # 'License :: OSI Approved :: BSD License',
- "Operating System :: OS Independent",
- # Specify the Python versions you support here. In particular, ensure
- # that you indicate whether you support Python 2, Python 3 or both.
- "Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.8",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: 3.10",
- ],
-)
+if __name__ == "__main__":
+ setup(
+ name="lightning-flash",
+ version=about.__version__,
+ description=about.__docs__,
+ author=about.__author__,
+ author_email=about.__author_email__,
+ url=about.__homepage__,
+ download_url="https://github.com/Lightning-AI/lightning-flash",
+ license=about.__license__,
+ package_dir={"": "src"},
+ packages=find_packages(where="src"),
+ long_description=_load_readme_description(_PATH_ROOT, homepage=about.__homepage__, ver=about.__version__),
+ long_description_content_type="text/markdown",
+ include_package_data=True,
+ entry_points={
+ "console_scripts": ["flash=flash.__main__:main"],
+ },
+ zip_safe=False,
+ keywords=["deep learning", "pytorch", "AI"],
+ python_requires=">=3.8",
+ install_requires=_load_requirements(path_dir=_PATH_REQUIRE, file_name="base.txt"),
+ extras_require=_get_extras(),
+ project_urls={
+ "Bug Tracker": "https://github.com/Lightning-AI/lightning-flash/issues",
+ "Documentation": "https://lightning-flash.rtfd.io/en/latest/",
+ "Source Code": "https://github.com/Lightning-AI/lightning-flash",
+ },
+ classifiers=[
+ "Environment :: Console",
+ "Natural Language :: English",
+ # How mature is this project? Common values are
+ # 3 - Alpha, 4 - Beta, 5 - Production/Stable
+ "Development Status :: 4 - Beta",
+ # Indicate who your project is intended for
+ "Intended Audience :: Developers",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Scientific/Engineering :: Image Recognition",
+ "Topic :: Scientific/Engineering :: Information Analysis",
+ # Pick your license as you wish
+ # 'License :: OSI Approved :: BSD License',
+ "Operating System :: OS Independent",
+ # Specify the Python versions you support here. In particular, ensure
+ # that you indicate whether you support Python 2, Python 3 or both.
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ ],
+ )
diff --git a/src/flash/__about__.py b/src/flash/__about__.py
index 01237fd7cb..f062db654f 100644
--- a/src/flash/__about__.py
+++ b/src/flash/__about__.py
@@ -1,8 +1,8 @@
-__version__ = "0.8.1.post0"
+__version__ = "0.8.2"
__author__ = "PyTorchLightning et al."
__author_email__ = "name@pytorchlightning.ai"
__license__ = "Apache-2.0"
-__copyright__ = f"Copyright (c) 2020-2022, {__author__}."
+__copyright__ = f"Copyright (c) 2020-2023, {__author__}."
__homepage__ = "https://github.com/Lightning-AI/lightning-flash"
__docs_url__ = "https://lightning-flash.readthedocs.io/en/stable/"
__docs__ = "Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes."
diff --git a/src/flash/__init__.py b/src/flash/__init__.py
index 25c88b2f8c..74aeac16c9 100644
--- a/src/flash/__init__.py
+++ b/src/flash/__init__.py
@@ -14,39 +14,43 @@
"""Root package info."""
import os
-from flash.__about__ import * # noqa: F401 F403
-from flash.core.utilities.imports import _TORCH_AVAILABLE
+import numpy
-if _TORCH_AVAILABLE:
- from flash.core.data.callback import FlashCallback
- from flash.core.data.data_module import DataModule
- from flash.core.data.io.input import DataKeys, Input
- from flash.core.data.io.input_transform import InputTransform
- from flash.core.data.io.output import Output
- from flash.core.data.io.output_transform import OutputTransform
- from flash.core.model import Task
- from flash.core.trainer import Trainer
- from flash.core.utilities.stages import RunningStage
+# adding compatibility for numpy >= 1.24
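+# (numpy 1.24 removed the deprecated aliases np.object, np.bool and np.float; re-creating them keeps older dependencies that still reference them importable)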
+for tp_name, tp_ins in [("object", object), ("bool", bool), ("float", float)]:
+ if not hasattr(numpy, tp_name):
+ setattr(numpy, tp_name, tp_ins)
- _PACKAGE_ROOT = os.path.dirname(__file__)
- ASSETS_ROOT = os.path.join(_PACKAGE_ROOT, "assets")
- PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
- _IS_TESTING = os.getenv("FLASH_TESTING", "0") == "1"
+from flash.__about__ import * # noqa: F401 E402 F403
+from flash.core.data.callback import FlashCallback # noqa: E402
+from flash.core.data.data_module import DataModule # noqa: E402
+from flash.core.data.io.input import DataKeys, Input # noqa: E402
+from flash.core.data.io.input_transform import InputTransform # noqa: E402
+from flash.core.data.io.output import Output # noqa: E402
+from flash.core.data.io.output_transform import OutputTransform # noqa: E402
+from flash.core.model import Task # noqa: E402
+from flash.core.trainer import Trainer # noqa: E402
+from flash.core.utilities.stages import RunningStage # noqa: E402
- if _IS_TESTING:
- from pytorch_lightning import seed_everything
+_PACKAGE_ROOT = os.path.dirname(__file__)
+ASSETS_ROOT = os.path.join(_PACKAGE_ROOT, "assets")
+PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
+_IS_TESTING = os.getenv("FLASH_TESTING", "0") == "1"
- seed_everything(42)
+if _IS_TESTING:
+ from pytorch_lightning import seed_everything
- __all__ = [
- "DataKeys",
- "DataModule",
- "FlashCallback",
- "Input",
- "InputTransform",
- "Output",
- "OutputTransform",
- "RunningStage",
- "Task",
- "Trainer",
- ]
+ seed_everything(42)
+
+__all__ = [
+ "DataKeys",
+ "DataModule",
+ "FlashCallback",
+ "Input",
+ "InputTransform",
+ "Output",
+ "OutputTransform",
+ "RunningStage",
+ "Task",
+ "Trainer",
+]
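
The compatibility shim added at the top of src/flash/__init__.py restores the builtin aliases that NumPy 1.24 removed; a standalone sketch of the same idea (assumes only that numpy is installed):

import numpy

# NumPy 1.24 dropped the deprecated aliases numpy.object, numpy.bool and numpy.float.
# Dependencies that still reference them would fail with AttributeError, so the missing
# names are pointed back at the builtins they used to alias.
for tp_name, tp_ins in [("object", object), ("bool", bool), ("float", float)]:
    if not hasattr(numpy, tp_name):
        setattr(numpy, tp_name, tp_ins)

assert numpy.float is float  # holds after the shim on NumPy >= 1.24, and natively on older releases
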
diff --git a/src/flash/audio/classification/data.py b/src/flash/audio/classification/data.py
index 951222b033..412713a9e0 100644
--- a/src/flash/audio/classification/data.py
+++ b/src/flash/audio/classification/data.py
@@ -32,18 +32,18 @@
from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE
from flash.core.data.utilities.classification import TargetFormatter
from flash.core.data.utilities.paths import PATH_TYPE
-from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE
+from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE, _TOPIC_IMAGE_AVAILABLE
from flash.core.utilities.stages import RunningStage
from flash.image.classification.data import MatplotlibVisualization
# Skip doctests if requirements aren't available
-if not _TOPIC_AUDIO_AVAILABLE:
+if not _TOPIC_AUDIO_AVAILABLE or not _TOPIC_IMAGE_AVAILABLE:
__doctest_skip__ = ["AudioClassificationData", "AudioClassificationData.*"]
class AudioClassificationData(DataModule):
"""The ``AudioClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of
- classmethods for loading data for audio classification."""
+ class methods for loading data for audio classification."""
input_transform_cls = AudioClassificationInputTransform
diff --git a/src/flash/audio/speech_recognition/collate.py b/src/flash/audio/speech_recognition/collate.py
index 0346bb04f4..90471bedf3 100644
--- a/src/flash/audio/speech_recognition/collate.py
+++ b/src/flash/audio/speech_recognition/collate.py
@@ -50,6 +50,7 @@ class DataCollatorCTCWithPadding:
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
+
"""
processor: AutoProcessor
diff --git a/src/flash/audio/speech_recognition/data.py b/src/flash/audio/speech_recognition/data.py
index b205afcd09..cc31f7df65 100644
--- a/src/flash/audio/speech_recognition/data.py
+++ b/src/flash/audio/speech_recognition/data.py
@@ -119,9 +119,7 @@ def from_files(
>>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)]
"""
- ds_kw = {
- "sampling_rate": sampling_rate,
- }
+ ds_kw = {"sampling_rate": sampling_rate}
return cls(
input_cls(RunningStage.TRAINING, train_files, train_targets, **ds_kw),
@@ -306,10 +304,7 @@ def from_csv(
>>> os.remove("predict_data.tsv")
"""
- ds_kw = {
- "input_key": input_field,
- "sampling_rate": sampling_rate,
- }
+ ds_kw = {"input_key": input_field, "sampling_rate": sampling_rate}
return cls(
input_cls(RunningStage.TRAINING, train_file, target_key=target_field, **ds_kw),
@@ -430,11 +425,7 @@ def from_json(
>>> os.remove("predict_data.json")
"""
- ds_kw = {
- "input_key": input_field,
- "sampling_rate": sampling_rate,
- "field": field,
- }
+ ds_kw = {"input_key": input_field, "sampling_rate": sampling_rate, "field": field}
return cls(
input_cls(RunningStage.TRAINING, train_file, target_key=target_field, **ds_kw),
@@ -580,9 +571,7 @@ def from_datasets(
>>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)]
"""
- ds_kw = {
- "sampling_rate": sampling_rate,
- }
+ ds_kw = {"sampling_rate": sampling_rate}
return cls(
input_cls(RunningStage.TRAINING, train_dataset, **ds_kw),
diff --git a/src/flash/core/classification.py b/src/flash/core/classification.py
index f8c86fe844..394b824421 100644
--- a/src/flash/core/classification.py
+++ b/src/flash/core/classification.py
@@ -127,6 +127,7 @@ class ClassificationOutput(Output):
Args:
multi_label: If true, treats outputs as multi label logits.
+
"""
def __init__(self, multi_label: bool = False):
@@ -146,8 +147,7 @@ def multi_label(self) -> bool:
@CLASSIFICATION_OUTPUTS(name="preds")
class PredsClassificationOutput(ClassificationOutput):
"""A :class:`~flash.core.classification.ClassificationOutput` which gets the
- :attr:`~flash.core.data.io.input.InputFormat.PREDS` from the sample.
- """
+ :attr:`~flash.core.data.io.input.InputFormat.PREDS` from the sample."""
def transform(self, sample: Any) -> Any:
if isinstance(sample, Mapping) and DataKeys.PREDS in sample:
diff --git a/src/flash/core/data/batch.py b/src/flash/core/data/batch.py
index ba16ec429c..617e03a1f8 100644
--- a/src/flash/core/data/batch.py
+++ b/src/flash/core/data/batch.py
@@ -35,8 +35,7 @@ def forward(self, sample: str):
sample = self.serve_input._call_load_sample(sample)
if not isinstance(sample, list):
sample = [sample]
- sample = self.collate_fn(sample)
- return sample
+ return self.collate_fn(sample)
def _is_list_like_excluding_str(x):
@@ -59,6 +58,7 @@ def default_uncollate(batch: Any) -> List[Any]:
ValueError: If the input is a ``dict`` whose values are not all list-like.
ValueError: If the input is a ``dict`` whose values are not all the same length.
ValueError: If the input is not a ``dict`` or list-like.
+
"""
if isinstance(batch, dict):
if any(not _is_list_like_excluding_str(sub_batch) for sub_batch in batch.values()):
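
For context on the default_uncollate docstring touched above, a minimal illustration of the dict case it documents; the commented output is illustrative, assuming equal-length list-like values:

from flash.core.data.batch import default_uncollate

# A collated batch: each key maps to one value per sample.
batch = {"input": [1, 2, 3], "target": [0, 1, 0]}
samples = default_uncollate(batch)
# Expected shape of the result: one dict per sample, e.g.
# [{'input': 1, 'target': 0}, {'input': 2, 'target': 1}, {'input': 3, 'target': 0}]
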
diff --git a/src/flash/core/data/data_module.py b/src/flash/core/data/data_module.py
index 849ff061c3..9118300c96 100644
--- a/src/flash/core/data/data_module.py
+++ b/src/flash/core/data/data_module.py
@@ -103,7 +103,6 @@ class DataModule(pl.LightningDataModule):
>>> datamodule = DataModule(train_input, sampler=WeightedRandomSampler([0.1, 0.5], 2), batch_size=1)
>>> print(datamodule.train_dataloader().sampler) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
-
"""
input_transform_cls = InputTransform
@@ -567,6 +566,7 @@ def _split_train_val(
Returns:
A tuple containing the training and validation datasets
+
"""
if not isinstance(val_split, float) or (isinstance(val_split, float) and val_split > 1 or val_split < 0):
diff --git a/src/flash/core/data/io/classification_input.py b/src/flash/core/data/io/classification_input.py
index 4c5cd60622..31daf8273c 100644
--- a/src/flash/core/data/io/classification_input.py
+++ b/src/flash/core/data/io/classification_input.py
@@ -64,5 +64,6 @@ def format_target(self, target: Any) -> Any:
Returns:
The formatted target.
+
"""
return getattr(self, "target_formatter", lambda x: x)(target)
diff --git a/src/flash/core/data/io/input.py b/src/flash/core/data/io/input.py
index 7e42c448f2..199aee6f00 100644
--- a/src/flash/core/data/io/input.py
+++ b/src/flash/core/data/io/input.py
@@ -89,6 +89,7 @@ def _has_len(data: Union[Sequence, Iterable]) -> bool:
Args:
data: The object to check for length support.
+
"""
try:
len(data)
diff --git a/src/flash/core/data/io/input_transform.py b/src/flash/core/data/io/input_transform.py
index f65af3d78c..d436ac681d 100644
--- a/src/flash/core/data/io/input_transform.py
+++ b/src/flash/core/data/io/input_transform.py
@@ -84,6 +84,7 @@ def per_sample_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_sample_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
pass
@@ -97,6 +98,7 @@ def train_per_sample_transform(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
+
"""
return self.per_sample_transform()
@@ -121,6 +123,7 @@ def val_per_sample_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_sample_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
return self.per_sample_transform()
@@ -134,6 +137,7 @@ def test_per_sample_transform(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
+
"""
return self.per_sample_transform()
@@ -158,6 +162,7 @@ def predict_per_sample_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_sample_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
return self.per_sample_transform()
@@ -182,6 +187,7 @@ def serve_per_sample_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_sample_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
return self.per_sample_transform()
@@ -210,6 +216,7 @@ def per_sample_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def per_sample_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
pass
@@ -223,6 +230,7 @@ def train_per_sample_transform_on_device(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
+
"""
return self.per_sample_transform_on_device()
@@ -247,6 +255,7 @@ def val_per_sample_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def per_sample_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
return self.per_sample_transform_on_device()
@@ -260,6 +269,7 @@ def test_per_sample_transform_on_device(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
+
"""
return self.per_sample_transform_on_device()
@@ -284,6 +294,7 @@ def predict_per_sample_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def per_sample_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
return self.per_sample_transform_on_device()
@@ -308,6 +319,7 @@ def serve_per_sample_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def serve_per_sample_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
return self.per_sample_transform_on_device()
@@ -336,6 +348,7 @@ def per_batch_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_batch_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
pass
@@ -349,6 +362,7 @@ def train_per_batch_transform(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
+
"""
return self.per_batch_transform()
@@ -373,6 +387,7 @@ def val_per_batch_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_batch_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
return self.per_batch_transform()
@@ -386,6 +401,7 @@ def test_per_batch_transform(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
+
"""
return self.per_batch_transform()
@@ -410,6 +426,7 @@ def predict_per_batch_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_batch_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
return self.per_batch_transform()
@@ -434,6 +451,7 @@ def serve_per_batch_transform(self) -> Callable:
class MyInputTransform(InputTransform):
def per_batch_transform(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
return self.per_batch_transform()
@@ -462,6 +480,7 @@ def per_batch_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def per_batch_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
pass
@@ -475,6 +494,7 @@ def train_per_batch_transform_on_device(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
+
"""
return self.per_batch_transform_on_device()
@@ -499,6 +519,7 @@ def val_per_batch_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def per_batch_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
return self.per_batch_transform_on_device()
@@ -512,6 +533,7 @@ def test_per_batch_transform_on_device(self) -> Callable:
DataKeys.TARGET: ...,
DataKeys.METADATA: ...,
}
+
"""
return self.per_batch_transform_on_device()
@@ -536,6 +558,7 @@ def predict_per_batch_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def per_batch_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
return self.per_batch_transform_on_device()
@@ -560,6 +583,7 @@ def serve_per_batch_transform_on_device(self) -> Callable:
class MyInputTransform(InputTransform):
def serve_per_batch_transform_on_device(self) -> Callable:
return ApplyToKeys("input", my_func)
+
"""
return self.per_batch_transform_on_device()
@@ -606,6 +630,7 @@ def _per_batch_transform(self, batch: Any, stage: RunningStage) -> Any:
.. note:: This option is mutually exclusive with :meth:`per_sample_transform_on_device`, since if both are
specified, uncollation has to be applied.
+
"""
return self.current_transform(stage=stage, current_fn="per_batch_transform")(batch)
@@ -620,6 +645,7 @@ def _per_sample_transform_on_device(self, sample: Any, stage: RunningStage) -> A
specified, uncollation has to be applied. .. note:: This function won't be called within the dataloader
workers, since to make that happen each of the workers would have to create it's own CUDA-context which
would pollute GPU memory (if on GPU).
+
"""
fn = self.current_transform(stage=stage, current_fn="per_sample_transform_on_device")
if isinstance(sample, list):
@@ -631,6 +657,7 @@ def _per_batch_transform_on_device(self, batch: Any, stage: RunningStage) -> Any
.. note:: This function won't be called within the dataloader workers, since to make that happen each of
the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU).
+
"""
return self.current_transform(stage=stage, current_fn="per_batch_transform_on_device")(batch)
diff --git a/src/flash/core/data/io/output.py b/src/flash/core/data/io/output.py
index 0b8e7467a8..e0765ca754 100644
--- a/src/flash/core/data/io/output.py
+++ b/src/flash/core/data/io/output.py
@@ -37,6 +37,7 @@ def transform(sample: Any) -> Any:
Returns:
The converted output.
+
"""
return sample
diff --git a/src/flash/core/data/io/output_transform.py b/src/flash/core/data/io/output_transform.py
index 0e691ce51a..c0e799542e 100644
--- a/src/flash/core/data/io/output_transform.py
+++ b/src/flash/core/data/io/output_transform.py
@@ -25,6 +25,7 @@ def per_batch_transform(batch: Any) -> Any:
"""Transforms to apply on a whole batch before uncollation to individual samples.
Can involve both CPU and Device transforms as this is not applied in separate workers.
+
"""
return batch
@@ -33,6 +34,7 @@ def per_sample_transform(sample: Any) -> Any:
"""Transforms to apply to a single sample after splitting up the batch.
Can involve both CPU and Device transforms as this is not applied in separate workers.
+
"""
return sample
@@ -41,6 +43,7 @@ def uncollate(batch: Any) -> Any:
"""Uncollates a batch into single samples.
Tries to preserve the type wherever possible.
+
"""
return default_uncollate(batch)
diff --git a/src/flash/core/data/splits.py b/src/flash/core/data/splits.py
index a51e29fade..b5b8492fef 100644
--- a/src/flash/core/data/splits.py
+++ b/src/flash/core/data/splits.py
@@ -21,6 +21,7 @@ class SplitDataset(Properties, Dataset):
split_ds = SplitDataset(dataset, indices=[10, 14, 25])
split_ds = SplitDataset(dataset, indices=[10, 10, 10, 14, 25], use_duplicated_indices=True)
+
"""
def __init__(
diff --git a/src/flash/core/data/utilities/data_frame.py b/src/flash/core/data/utilities/data_frame.py
index d2f4d7fc8f..ab4f99015b 100644
--- a/src/flash/core/data/utilities/data_frame.py
+++ b/src/flash/core/data/utilities/data_frame.py
@@ -30,6 +30,7 @@ def resolve_targets(data_frame: pd.DataFrame, target_keys: Union[str, List[str]]
Args:
data_frame: The ``pd.DataFrame`` containing the target column / columns.
target_keys: The column in the data frame (or a list of columns) from which to resolve the target.
+
"""
if not isinstance(target_keys, List):
return data_frame[target_keys].tolist()
@@ -63,6 +64,7 @@ def resolve_files(
root: The root path to use when resolving files.
resolver: The resolver function to use. This function should receive the root and a file ID as input and return
the path to an existing file.
+
"""
if resolver is None:
resolver = default_resolver
diff --git a/src/flash/core/data/utilities/loading.py b/src/flash/core/data/utilities/loading.py
index c0675f0278..cc5e5067db 100644
--- a/src/flash/core/data/utilities/loading.py
+++ b/src/flash/core/data/utilities/loading.py
@@ -68,8 +68,7 @@ def _load_image_from_image(file):
img = Image.open(file)
img.load()
- img = img.convert("RGB")
- return img
+ return img.convert("RGB")
def _load_image_from_numpy(file):
@@ -182,6 +181,7 @@ def load_image(file_path: str):
Args:
file_path: The image file to load.
+
"""
return load(file_path, _image_loaders)
@@ -193,6 +193,7 @@ def load_spectrogram(file_path: str, sampling_rate: int = 16000, n_fft: int = 40
file_path: The file to load.
sampling_rate: The sampling rate to resample to if loading from an audio file.
n_fft: The size of the FFT to use when creating a spectrogram from an audio file.
+
"""
loaders = copy.copy(_spectrogram_loaders)
loaders[AUDIO_EXTENSIONS] = partial(loaders[AUDIO_EXTENSIONS], sampling_rate=sampling_rate, n_fft=n_fft)
@@ -205,6 +206,7 @@ def load_audio(file_path: str, sampling_rate: int = 16000):
Args:
file_path: The file to load.
sampling_rate: The sampling rate to resample to.
+
"""
loaders = {
extensions: partial(loader, sampling_rate=sampling_rate) for extensions, loader in _audio_loaders.items()
@@ -218,6 +220,7 @@ def load_data_frame(file_path: str, encoding: str = "utf-8"):
Args:
file_path: The file to load.
encoding: The encoding to use when reading the file.
+
"""
loaders = {extensions: partial(loader, encoding=encoding) for extensions, loader in _data_frame_loaders.items()}
return load(file_path, loaders)
diff --git a/src/flash/core/data/utilities/paths.py b/src/flash/core/data/utilities/paths.py
index 7d8850070e..96939e0e15 100644
--- a/src/flash/core/data/utilities/paths.py
+++ b/src/flash/core/data/utilities/paths.py
@@ -32,6 +32,7 @@ def has_file_allowed_extension(filename: PATH_TYPE, extensions: Tuple[str, ...])
Returns:
bool: True if the filename ends with one of given extensions
+
"""
return str(filename).lower().endswith(extensions)
@@ -59,6 +60,7 @@ def make_dataset(
Returns:
(files, targets) Tuple containing the list of files and corresponding list of targets.
+
"""
files, targets = [], []
directory = os.path.expanduser(str(directory))
@@ -104,6 +106,7 @@ def list_subdirs(folder: PATH_TYPE) -> List[str]:
Returns:
The list of subdirectories.
+
"""
return list(sorted_alphanumeric(d.name for d in os.scandir(str(folder)) if d.is_dir()))
@@ -146,6 +149,7 @@ def filter_valid_files(
Returns:
The filtered lists.
+
"""
if not isinstance(files, List):
files = [files]
diff --git a/src/flash/core/data/utilities/samples.py b/src/flash/core/data/utilities/samples.py
index 70a2bdf8db..8f26462b1d 100644
--- a/src/flash/core/data/utilities/samples.py
+++ b/src/flash/core/data/utilities/samples.py
@@ -32,6 +32,7 @@ def to_sample(input: Any) -> Dict[str, Any]:
Returns:
A sample dictionary.
+
"""
if isinstance(input, dict) and DataKeys.INPUT in input:
return input
@@ -51,6 +52,7 @@ def to_samples(inputs: List[Any], targets: Optional[List[Any]] = None) -> List[D
Returns:
A list of sample dictionaries.
+
"""
if targets is None:
return [to_sample(input) for input in inputs]
diff --git a/src/flash/core/data/utils.py b/src/flash/core/data/utils.py
index fb85435f45..e142615d74 100644
--- a/src/flash/core/data/utils.py
+++ b/src/flash/core/data/utils.py
@@ -76,6 +76,7 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
>>> download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data")
>>> os.listdir("./data") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
[...]
+
"""
# Disable warning about making an insecure request
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@@ -144,5 +145,4 @@ def convert_to_modules(transforms: Optional[Dict[str, Callable]]):
transforms = apply_to_collection(transforms, Callable, FuncModule, wrong_dtype=nn.Module)
transforms = apply_to_collection(transforms, Mapping, nn.ModuleDict, wrong_dtype=nn.ModuleDict)
- transforms = apply_to_collection(transforms, Iterable, nn.ModuleList, wrong_dtype=(nn.ModuleList, nn.ModuleDict))
- return transforms
+ return apply_to_collection(transforms, Iterable, nn.ModuleList, wrong_dtype=(nn.ModuleList, nn.ModuleDict))
diff --git a/src/flash/core/finetuning.py b/src/flash/core/finetuning.py
index 71b684de4b..be6dfb99cb 100644
--- a/src/flash/core/finetuning.py
+++ b/src/flash/core/finetuning.py
@@ -221,6 +221,7 @@ class FlashDeepSpeedFinetuning(FlashBaseFinetuning):
DeepSpeed cannot store and load its parameters when working with Lightning. So FlashDeepSpeedFinetuning overrides
`_store` to not store its parameters.
+
"""
def _store(
diff --git a/src/flash/core/heads.py b/src/flash/core/heads.py
index a4a5416633..160d9b2a2f 100644
--- a/src/flash/core/heads.py
+++ b/src/flash/core/heads.py
@@ -30,6 +30,7 @@ def _load_linear_head(num_features: int, num_classes: int) -> nn.Module:
Returns:
nn.Module: Linear head.
+
"""
return nn.Linear(num_features, num_classes)
diff --git a/src/flash/core/integrations/fiftyone/utils.py b/src/flash/core/integrations/fiftyone/utils.py
index 62a5de5094..8a14dd21a5 100644
--- a/src/flash/core/integrations/fiftyone/utils.py
+++ b/src/flash/core/integrations/fiftyone/utils.py
@@ -26,8 +26,7 @@ def visualize(
wait: Optional[bool] = False,
**kwargs,
) -> Optional[Session]:
- """Visualizes predictions from a model with a FiftyOne Output in the
- :ref:`FiftyOne App <fiftyone:fiftyone-app>`.
+ """Visualizes predictions from a model with a FiftyOne Output in the :ref:`FiftyOne App <fiftyone:fiftyone-app>`.
This method can be used in all of the following environments:
@@ -61,6 +60,7 @@ def visualize(
Returns:
a :class:`fiftyone:fiftyone.core.session.Session`
+
"""
if flash._IS_TESTING:
return None
diff --git a/src/flash/core/integrations/labelstudio/input.py b/src/flash/core/integrations/labelstudio/input.py
index a288241e96..ed2aa63bdb 100644
--- a/src/flash/core/integrations/labelstudio/input.py
+++ b/src/flash/core/integrations/labelstudio/input.py
@@ -310,7 +310,7 @@ def convert_to_encodedvideo(self, dataset):
if len(dataset) > 0:
from pytorchvideo.data import LabeledVideoDataset
- dataset = LabeledVideoDataset(
+ return LabeledVideoDataset(
[
(
os.path.join(self._data_folder, sample["file_upload"]),
@@ -322,7 +322,6 @@ def convert_to_encodedvideo(self, dataset):
decode_audio=self.decode_audio,
decoder=self.decoder,
)
- return dataset
return []
diff --git a/src/flash/core/integrations/pytorch_forecasting/adapter.py b/src/flash/core/integrations/pytorch_forecasting/adapter.py
index a9fc69bf68..d84277c121 100644
--- a/src/flash/core/integrations/pytorch_forecasting/adapter.py
+++ b/src/flash/core/integrations/pytorch_forecasting/adapter.py
@@ -39,6 +39,7 @@ class PatchTimeSeriesDataSet(TimeSeriesDataSet):
"""Hack to prevent index construction or data validation / conversion when instantiating model.
This enables the ``TimeSeriesDataSet`` to be created from a single row of data.
+
"""
def _construct_index(self, data: DataFrame, predict_mode: bool) -> DataFrame:
diff --git a/src/flash/core/integrations/pytorch_tabular/adapter.py b/src/flash/core/integrations/pytorch_tabular/adapter.py
index b9aca1f243..e68b68d35b 100644
--- a/src/flash/core/integrations/pytorch_tabular/adapter.py
+++ b/src/flash/core/integrations/pytorch_tabular/adapter.py
@@ -50,6 +50,7 @@ def from_task(
"categorical_dim": len(categorical_fields),
"continuous_dim": num_features - len(categorical_fields),
"output_dim": output_dim,
+ "embedded_cat_dim": sum([embd_dim for _, embd_dim in embedding_sizes]),
}
return cls(
task_type,
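
The new "embedded_cat_dim" entry is simply the total width of the categorical embeddings; a tiny self-contained illustration with made-up embedding sizes:

# Each entry is (cardinality, embedding_dim) for one categorical field (values are made up).
embedding_sizes = [(10, 5), (4, 3), (7, 4)]

# Width of the concatenated categorical embeddings, as passed to PyTorch Tabular above.
embedded_cat_dim = sum(embd_dim for _, embd_dim in embedding_sizes)
print(embedded_cat_dim)  # 12
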
diff --git a/src/flash/core/integrations/pytorch_tabular/backbones.py b/src/flash/core/integrations/pytorch_tabular/backbones.py
index 72084ae0d8..89011a1f51 100644
--- a/src/flash/core/integrations/pytorch_tabular/backbones.py
+++ b/src/flash/core/integrations/pytorch_tabular/backbones.py
@@ -30,6 +30,7 @@
AutoIntConfig,
CategoryEmbeddingModelConfig,
FTTransformerConfig,
+ GatedAdditiveTreeEnsembleConfig,
NodeConfig,
TabNetModelConfig,
TabTransformerConfig,
@@ -56,8 +57,7 @@ def _read_parse_config(config, cls):
)
else:
raise ValueError(f"{config} is not a valid path")
- config = OmegaConf.structured(config)
- return config
+ return OmegaConf.structured(config)
def load_pytorch_tabular(
model_config_class,
@@ -88,8 +88,9 @@ def load_pytorch_tabular(
AutoIntConfig,
NodeConfig,
CategoryEmbeddingModelConfig,
+ GatedAdditiveTreeEnsembleConfig,
],
- ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"],
+ ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding", "gate"],
):
PYTORCH_TABULAR_BACKBONES(
functools.partial(load_pytorch_tabular, model_config_class),
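
With GatedAdditiveTreeEnsembleConfig registered under the name "gate", it can be selected like any other PyTorch Tabular backbone; a hedged sketch, assuming an existing TabularClassificationData instance called datamodule:

from flash.tabular import TabularClassifier

# "gate" selects the newly registered Gated Additive Tree Ensemble backbone;
# `datamodule` is assumed to be a TabularClassificationData built elsewhere.
model = TabularClassifier.from_data(datamodule, backbone="gate")
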
diff --git a/src/flash/core/model.py b/src/flash/core/model.py
index 3bdd199f3b..df1c92399c 100644
--- a/src/flash/core/model.py
+++ b/src/flash/core/model.py
@@ -323,6 +323,14 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, FineTuningHooks
`metric(preds,target)` and return a single scalar tensor.
output_transform: :class:`~flash.core.data.io.output_transform.OutputTransform` to use as the default for this
task.
+
+ >>> Task() # doctest: +ELLIPSIS
+ Task(
+ (train_metrics): ModuleDict()
+ (val_metrics): ModuleDict()
+ (test_metrics): ModuleDict()
+ )
+
"""
optimizers_registry: FlashRegistry = _OPTIMIZERS_REGISTRY
@@ -374,6 +382,7 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
Returns:
A dict containing both the loss and relevant metrics
+
"""
x, y = batch
y_hat = self(x)
diff --git a/src/flash/core/optimizers/lamb.py b/src/flash/core/optimizers/lamb.py
index d0b07e2615..3485ebcbb8 100644
--- a/src/flash/core/optimizers/lamb.py
+++ b/src/flash/core/optimizers/lamb.py
@@ -50,6 +50,7 @@ class LAMB(Optimizer):
(default: False)
Example:
+ >>> from torch import nn
>>> model = nn.Linear(10, 1)
>>> optimizer = LAMB(model.parameters(), lr=0.1)
>>> optimizer.zero_grad()
@@ -104,6 +105,7 @@ def step(self, closure=None):
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
+
"""
loss = None
if closure is not None:
diff --git a/src/flash/core/optimizers/lars.py b/src/flash/core/optimizers/lars.py
index f89f2cba0b..86e6058e14 100644
--- a/src/flash/core/optimizers/lars.py
+++ b/src/flash/core/optimizers/lars.py
@@ -45,6 +45,7 @@ class LARS(Optimizer):
eps (float, optional): eps for division denominator (default: 1e-8)
Example:
+ >>> from torch import nn
>>> model = nn.Linear(10, 1)
>>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
@@ -72,6 +73,7 @@ class LARS(Optimizer):
Parameters with weight decay set to 0 will automatically be excluded from
layer-wise LR scaling. This is to ensure consistency with papers like SimCLR
and BYOL.
+
"""
def __init__(
@@ -120,6 +122,7 @@ def step(self, closure=None):
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
+
"""
loss = None
if closure is not None:
diff --git a/src/flash/core/optimizers/lr_scheduler.py b/src/flash/core/optimizers/lr_scheduler.py
index e0e918ca9f..1ecb59a2f9 100644
--- a/src/flash/core/optimizers/lr_scheduler.py
+++ b/src/flash/core/optimizers/lr_scheduler.py
@@ -46,6 +46,8 @@ class LinearWarmupCosineAnnealingLR(_LRScheduler):
train and validation methods.
Example:
+ >>> from torch import nn
+ >>> from torch.optim import Adam
>>> layer = nn.Linear(10, 1)
>>> optimizer = Adam(layer.parameters(), lr=0.02)
>>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
@@ -61,6 +63,7 @@ class LinearWarmupCosineAnnealingLR(_LRScheduler):
... scheduler.step(epoch)
... # train(...)
... # validate(...)
+
"""
def __init__(
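
The doctest above now carries its own imports; the same example as a plain script, mirroring the docstring:

from torch import nn
from torch.optim import Adam

from flash.core.optimizers import LinearWarmupCosineAnnealingLR

layer = nn.Linear(10, 1)
optimizer = Adam(layer.parameters(), lr=0.02)
# Warm the learning rate up linearly for 10 epochs, then cosine-anneal until epoch 40.
scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)

for epoch in range(40):
    # train(...)
    # validate(...)
    scheduler.step()
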
diff --git a/src/flash/core/registry.py b/src/flash/core/registry.py
index b968ee1934..924cce8219 100644
--- a/src/flash/core/registry.py
+++ b/src/flash/core/registry.py
@@ -158,6 +158,7 @@ def __call__(
"""This function is used to register new functions to the registry along their metadata.
Functions can be filtered using metadata using the ``get`` function.
+
"""
if providers is not None:
metadata["providers"] = providers
diff --git a/src/flash/core/serve/_compat/cached_property.py b/src/flash/core/serve/_compat/cached_property.py
index 50327f8d3f..29b9592cb4 100644
--- a/src/flash/core/serve/_compat/cached_property.py
+++ b/src/flash/core/serve/_compat/cached_property.py
@@ -3,6 +3,7 @@
cached_property() - computed once per instance, cached as attribute
credits: https://github.com/penguinolog/backports.cached_property
+
"""
__all__ = ("cached_property",)
diff --git a/src/flash/core/serve/component.py b/src/flash/core/serve/component.py
index c267ea49b8..8ee20b540e 100644
--- a/src/flash/core/serve/component.py
+++ b/src/flash/core/serve/component.py
@@ -55,6 +55,7 @@ class to perform the analysis on
------
SyntaxError
If parameters are not specified correctly.
+
"""
params = inspect.signature(cls.__init__).parameters
if len(params) > 3:
@@ -89,6 +90,7 @@ def _validate_model_args(
If an empty iterable is passed as the model argument
TypeError
If the args do not contain properly formatted model references
+
"""
if isiterable(args) and len(args) == 0:
raise ValueError(f"Iterable args={args} must have length >= 1")
@@ -122,6 +124,7 @@ def _validate_config_args(config: Optional[Dict[str, Union[str, int, float, byte
If ``config`` is a dict with invalid key/values
ValueError
If ``config`` is a dict with 0 arguments
+
"""
if config is None:
return
@@ -183,6 +186,7 @@ def __call__(cls, *args, **kwargs):
super().__call__() within metaclass means: return instance created by calling metaclass __prepare__ -> __new__
-> __init__
+
"""
klass = super().__call__(*args, **kwargs)
klass._flashserve_meta_ = replace(klass._flashserve_meta_)
@@ -203,6 +207,7 @@ class ModelComponent(metaclass=FlashServeMeta):
assets, etc. The specification must be YAML serializable and loadable to/from a fully initialized instance. It
must contain the minimal set of information necessary to find and initialize its dependencies (assets) and
itself.
+
"""
_flashserve_meta_: Optional[Union[BoundMeta, UnboundMeta]] = None
@@ -211,6 +216,7 @@ def __flashserve_init__(self, models, *, config=None):
"""Do a bunch of setup.
instance's __flashserve_init__ calls subclass __init__ in turn.
+
"""
_validate_model_args(models)
_validate_config_args(config)
diff --git a/src/flash/core/serve/composition.py b/src/flash/core/serve/composition.py
index d627c02995..616464bb9a 100644
--- a/src/flash/core/serve/composition.py
+++ b/src/flash/core/serve/composition.py
@@ -63,6 +63,7 @@ class Composition(ServerMixin):
which provides introspection of components, endpoints, etc.
* We plan to add some user-facing API to the ``Composition`` object
which allows for modification of the composition.
+
"""
_uid_comps: Dict[str, ModelComponent]
diff --git a/src/flash/core/serve/core.py b/src/flash/core/serve/core.py
index c55be4641f..ce947b3122 100644
--- a/src/flash/core/serve/core.py
+++ b/src/flash/core/serve/core.py
@@ -32,6 +32,7 @@ class Endpoint:
outputs
The full name of a component output. Typically, specified by just passing
in the component parameter attribute (i.e.``component.outputs.bar``).
+
"""
route: str
@@ -99,6 +100,7 @@ class Servable:
----
* How to handle ``__init__`` args for ``torch.nn.Module``
* How to handle ``__init__`` args not recorded in hparams of ``pl.LightningModule``
+
"""
@requires("serve")
@@ -151,6 +153,7 @@ class Connection(NamedTuple):
* This data structure should not be instantiated directly! The
class_methods attached to the class are the intended mechanisms to create
a new instance.
+
"""
source_component: str
@@ -191,6 +194,7 @@ class Parameter:
Which component this type is associated with
position
Position in the while exposing it i.e `inputs` or `outputs`
+
"""
name: str
@@ -220,6 +224,7 @@ def __terminate_invalid_connection_request(self, other: "Parameter", dunder_meth
TypeError, RuntimeError
if the verification fails, we throw an exception to stop the
connection from being created.
+
"""
# assert this is actually a class object we can compare against.
if not isinstance(other, self.__class__) or (other.__class__ != self.__class__):
@@ -313,6 +318,7 @@ def make_parameter_container(data: Dict[str, Parameter]) -> ParameterContainer:
* parameter name must be valid python attribute (identifier) and
cannot be a builtin keyword. input names should have been validated
by this point.
+
"""
dataclass_fields = [(param_name, type(param)) for param_name, param in data.items()]
ParameterContainer = make_dataclass(
@@ -335,6 +341,7 @@ def make_param_dict(
Tuple[Dict[str, Parameter], Dict[str, Parameter]]
Element[0] == Input parameter dict
Element[1] == Output parameter dict.
+
"""
flashserve_inp_params, flashserve_out_params = {}, {}
for inp_key, inp_dtype in inputs.items():
diff --git a/src/flash/core/serve/dag/optimization.py b/src/flash/core/serve/dag/optimization.py
index 4ab3a07ef2..5b8de1196a 100644
--- a/src/flash/core/serve/dag/optimization.py
+++ b/src/flash/core/serve/dag/optimization.py
@@ -19,18 +19,20 @@ def cull(dsk, keys):
Examples
--------
+ >>> from flash.core.serve.dag.utils_test import add, inc
>>> d = {'x': 1, 'y': (inc, 'x'), 'out': (add, 'x', 10)}
- >>> dsk, dependencies = cull(d, 'out') # doctest: +SKIP
- >>> dsk # doctest: +SKIP
- {'x': 1, 'out': (add, 'x', 10)}
- >>> dependencies # doctest: +SKIP
- {'x': set(), 'out': set(['x'])}
+ >>> dsk, dependencies = cull(d, 'out')
+ >>> dsk # doctest: +ELLIPSIS
+ {'out': (<function add at ...>, 'x', 10), 'x': 1}
+ >>> dependencies
+ {'out': ['x'], 'x': []}
Returns
-------
dsk: culled graph
dependencies: Dict mapping {key: [deps]}. Useful side effect to accelerate
other optimizations, notably fuse.
+
"""
if not isinstance(keys, (list, set)):
keys = [keys]
@@ -95,22 +97,24 @@ def fuse_linear(dsk, keys=None, dependencies=None, rename_keys=True):
Examples
--------
+ >>> from flash.core.serve.dag.utils_test import inc
>>> d = {'a': 1, 'b': (inc, 'a'), 'c': (inc, 'b')}
>>> dsk, dependencies = fuse(d)
- >>> dsk # doctest: +SKIP
- {'a-b-c': (inc, (inc, 1)), 'c': 'a-b-c'}
+ >>> dsk # doctest: +ELLIPSIS
+ {'c': 'a-b-c', 'a-b-c': (<function inc at ...>, (<function inc at ...>, 1))}
>>> dsk, dependencies = fuse(d, rename_keys=False)
- >>> dsk # doctest: +SKIP
- {'c': (inc, (inc, 1))}
+ >>> dsk # doctest: +ELLIPSIS
+ {'c': (<function inc at ...>, (<function inc at ...>, 1))}
>>> dsk, dependencies = fuse(d, keys=['b'], rename_keys=False)
- >>> dsk # doctest: +SKIP
- {'b': (inc, 1), 'c': (inc, 'b')}
+ >>> dsk # doctest: +ELLIPSIS
+ {'b': (<function inc at ...>, 1), 'c': (<function inc at ...>, 'b')}
Returns
-------
dsk: output graph with keys fused
dependencies: dict mapping dependencies after fusion. Useful side effect
to accelerate other downstream optimizations.
+
"""
if keys is not None and not isinstance(keys, set):
if not isinstance(keys, list):
@@ -226,13 +230,15 @@ def inline(dsk, keys=None, inline_constants=True, dependencies=None):
Examples
--------
+ >>> from flash.core.serve.dag.utils_test import add, inc
>>> d = {'x': 1, 'y': (inc, 'x'), 'z': (add, 'x', 'y')}
- >>> inline(d) # doctest: +SKIP
- {'x': 1, 'y': (inc, 1), 'z': (add, 1, 'y')}
- >>> inline(d, keys='y') # doctest: +SKIP
- {'x': 1, 'y': (inc, 1), 'z': (add, 1, (inc, 1))}
- >>> inline(d, keys='y', inline_constants=False) # doctest: +SKIP
- {'x': 1, 'y': (inc, 1), 'z': (add, 'x', (inc, 'x'))}
+ >>> inline(d) # doctest: +ELLIPSIS
+ {'x': 1, 'y': (<function inc at ...>, 1), 'z': (<function add at ...>, 1, 'y')}
+ >>> inline(d, keys='y') # doctest: +ELLIPSIS
+ {'x': 1, 'y': (<function inc at ...>, 1), 'z': (<function add at ...>, 1, (<function inc at ...>, 1))}
+ >>> inline(d, keys='y', inline_constants=False) # doctest: +ELLIPSIS
+ {'x': 1, 'y': (<function inc at ...>, 'x'), 'z': (<function add at ...>, 'x', (<function inc at ...>, 'x'))}
+
"""
if dependencies and isinstance(next(iter(dependencies.values())), list):
dependencies = {k: set(v) for k, v in dependencies.items()}
@@ -273,23 +279,27 @@ def inline_functions(dsk, output, fast_functions=None, inline_constants=False, d
Examples
--------
- >>> double = lambda x: x*2 # doctest: +SKIP
- >>> dsk = {'out': (add, 'i', 'd'), # doctest: +SKIP
+ >>> from flash.core.serve.dag.utils_test import add, inc
+ >>> double = lambda x: x*2
+ >>> dsk = {'out': (add, 'i', 'd'),
... 'i': (inc, 'x'),
... 'd': (double, 'y'),
... 'x': 1, 'y': 1}
- >>> inline_functions(dsk, [], [inc]) # doctest: +SKIP
- {'out': (add, (inc, 'x'), 'd'),
- 'd': (double, 'y'),
- 'x': 1, 'y': 1}
+ >>> inline_functions(dsk, [], [inc]) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+ {'x': 1,
+ 'out': (<function add at ...>, (<function inc at ...>, 'x'), 'd'),
+ 'd': (<function <lambda> at ...>, 'y'),
+ 'y': 1}
Protect output keys. In the example below ``i`` is not inlined because it
is marked as an output key.
- >>> inline_functions(dsk, ['i', 'out'], [inc, double]) # doctest: +SKIP
- {'out': (add, 'i', (double, 'y')),
- 'i': (inc, 'x'),
- 'x': 1, 'y': 1}
+ >>> inline_functions(dsk, ['i', 'out'], [inc, double]) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+ {'y': 1,
+ 'out': (<function add at ...>, 'i', (<function <lambda> at ...>, 'y')),
+ 'i': (<function inc at ...>, 'x'),
+ 'x': 1}
+
"""
if not fast_functions:
return dsk
@@ -328,9 +338,11 @@ def functions_of(task):
Examples
--------
- >>> task = (add, (mul, 1, 2), (inc, 3)) # doctest: +SKIP
- >>> functions_of(task) # doctest: +SKIP
- set([add, mul, inc])
+ >>> from flash.core.serve.dag.utils_test import add, inc, mul
+ >>> task = (add, (mul, 1, 2), (inc, 3))
+ >>> sorted(functions_of(task), key=str) # doctest: +ELLIPSIS
+ [<function add at ...>, <function inc at ...>, <function mul at ...>]
+
"""
funcs = set()
@@ -868,6 +880,7 @@ class SubgraphCallable:
A list of keys to be used as arguments to the callable.
name : str, optional
The name to use for the function.
+
"""
__slots__ = ("dsk", "outkey", "inkeys", "name")
diff --git a/src/flash/core/serve/dag/order.py b/src/flash/core/serve/dag/order.py
index fff934f783..69447fad22 100644
--- a/src/flash/core/serve/dag/order.py
+++ b/src/flash/core/serve/dag/order.py
@@ -74,6 +74,7 @@
difference exists between two keys, use the key name to break ties.
This relies on the regularity of graph constructors like dask.array to be a
good proxy for ordering. This is usually a good idea and a sane default.
+
"""
from collections import defaultdict
@@ -106,9 +107,11 @@ def order(dsk, dependencies=None):
Examples
--------
+ >>> from flash.core.serve.dag.utils_test import add, inc
>>> dsk = {'a': 1, 'b': 2, 'c': (inc, 'a'), 'd': (add, 'b', 'c')}
>>> order(dsk)
{'a': 0, 'c': 1, 'b': 2, 'd': 3}
+
"""
if not dsk:
return {}
@@ -158,6 +161,7 @@ def dependents_key(x):
"""Choose a path from our starting task to our tactical goal.
This path is connected to a large goal, but focuses on completing a small goal and being memory efficient.
+
"""
return (
# Focus on being memory-efficient
@@ -171,6 +175,7 @@ def dependencies_key(x):
"""Choose which dependency to run as part of a reverse DFS.
This is very similar to both ``initial_stack_key``.
+
"""
num_dependents = len(dependents[x])
(
@@ -540,6 +545,8 @@ def graph_metrics(dependencies, dependents, total_dependencies):
Examples
--------
+ >>> from flash.core.serve.dag.task import get_deps
+ >>> from flash.core.serve.dag.utils_test import add, inc
>>> dsk = {'a1': 1, 'b1': (inc, 'a1'), 'b2': (inc, 'a1'), 'c1': (inc, 'b1')}
>>> dependencies, dependents = get_deps(dsk)
>>> _, total_dependencies = ndependencies(dependencies, dependents)
@@ -614,6 +621,8 @@ def ndependencies(dependencies, dependents):
Examples
--------
+ >>> from flash.core.serve.dag.task import get_deps
+ >>> from flash.core.serve.dag.utils_test import add, inc
>>> dsk = {'a': 1, 'b': (inc, 'a'), 'c': (inc, 'b')}
>>> dependencies, dependents = get_deps(dsk)
>>> num_dependencies, total_dependencies = ndependencies(dependencies, dependents)
@@ -624,6 +633,7 @@ def ndependencies(dependencies, dependents):
-------
num_dependencies: Dict[key, int]
total_dependencies: Dict[key, int]
+
"""
num_needed = {}
result = {}
@@ -657,7 +667,7 @@ class StrComparable:
When comparing two objects of different types Python fails
- >>> 'a' < 1 # doctest: +SKIP
+ >>> 'a' < 1
Traceback (most recent call last):
...
TypeError: '<' not supported between instances of 'str' and 'int'
@@ -667,6 +677,7 @@ class StrComparable:
>>> StrComparable('a') < StrComparable(1)
False
+
"""
__slots__ = ("obj",)
diff --git a/src/flash/core/serve/dag/rewrite.py b/src/flash/core/serve/dag/rewrite.py
index 63ef792904..aebe058ba6 100644
--- a/src/flash/core/serve/dag/rewrite.py
+++ b/src/flash/core/serve/dag/rewrite.py
@@ -46,6 +46,7 @@ class Traverser:
current
The head of the current element in the traversal. This is simply `head`
applied to the attribute `term`.
+
"""
def __init__(self, term, stack=None):
@@ -64,6 +65,7 @@ def copy(self):
"""Copy the traverser in its current state.
This allows the traversal to be pushed onto a stack, for easy backtracking.
+
"""
return Traverser(self.term, deque(self._stack))
@@ -92,6 +94,7 @@ class Token:
"""A token object.
Used to express certain objects in the traversal of a task or pattern.
+
"""
def __init__(self, name):
@@ -172,6 +175,7 @@ class RewriteRule:
... else:
... return list, x
>>> rule = RewriteRule(lhs, repl_list, variables)
+
"""
def __init__(self, lhs, rhs, vars=()):
@@ -224,8 +228,8 @@ class RuleSet:
>>> rs.rewrite((add, 2, 0)) # Apply ruleset to single task
2
- >>> rs.rewrite((f, (g, 'a', 3))) # doctest: +SKIP
- (h, 'a', 3)
+ >>> rs.rewrite((f, (g, 'a', 3))) # doctest: +ELLIPSIS
+ (<function h at ...>, 'a', 3)
>>> dsk = {'a': (add, 2, 0), # Apply ruleset to full dask graph
... 'b': (f, (g, 'a', 3))}
@@ -234,6 +238,7 @@ class RuleSet:
----------
rules : list
A list of `RewriteRule`s included in the `RuleSet`.
+
"""
def __init__(self, *rules):
@@ -255,6 +260,7 @@ def add(self, rule):
Parameters
----------
rule : RewriteRule
+
"""
if not isinstance(rule, RewriteRule):
@@ -288,6 +294,7 @@ def iter_matches(self, term):
Tuples of `(rule, subs)`, where `rule` is the rewrite rule being
matched, and `subs` is a dictionary mapping the variables in the lhs
of the rule to their matching values in the term.
+
"""
S = Traverser(term)
@@ -328,8 +335,10 @@ def rewrite(self, task, strategy="bottom_up"):
Suppose there was a function `add` that returned the sum of 2 numbers,
and another function `double` that returned twice its input:
- >>> add = lambda x, y: x + y
- >>> double = lambda x: 2*x
+ >>> def add(x, y):
+ ... return x + y
+ >>> def double(x):
+ ... return 2*x
Now suppose `double` was *significantly* faster than `add`, so
you'd like to replace all expressions `(add, x, x)` with `(double,
@@ -341,14 +350,14 @@ def rewrite(self, task, strategy="bottom_up"):
This can then be applied to terms to perform the rewriting:
>>> term = (add, (add, 2, 2), (add, 2, 2))
- >>> rs.rewrite(term) # doctest: +SKIP
- (double, (double, 2))
+ >>> rs.rewrite(term) # doctest: +ELLIPSIS
+ (<function double at ...>, (<function double at ...>, 2))
If we only wanted to apply this to the top level of the term, the
`strategy` kwarg can be set to "top_level".
- >>> rs.rewrite(term) # doctest: +SKIP
- (double, (add, 2, 2))
+ >>> rs.rewrite(term) # doctest: +ELLIPSIS
+ (<function double at ...>, (<function double at ...>, 2))
"""
return strategies[strategy](self, task)
@@ -421,6 +430,7 @@ def _process_match(rule, syms):
A dictionary of {vars : subterms} describing the substitution to make the
pattern equivalent with the term. Returns `None` if the match is
invalid.
+
"""
subs = {}
diff --git a/src/flash/core/serve/dag/task.py b/src/flash/core/serve/dag/task.py
index d1903d8819..59bb7875f8 100644
--- a/src/flash/core/serve/dag/task.py
+++ b/src/flash/core/serve/dag/task.py
@@ -19,6 +19,7 @@ def ishashable(x):
True
>>> ishashable([1])
False
+
"""
try:
hash(x)
@@ -32,10 +33,12 @@ def istask(x):
Examples
--------
+ >>> from flash.core.serve.dag.utils_test import inc
>>> istask((inc, 1))
True
>>> istask(1)
False
+
"""
return type(x) is tuple and x and callable(x[0])
@@ -64,6 +67,7 @@ def _execute_task(arg, cache):
Examples
--------
+ >>> from flash.core.serve.dag.utils_test import add, inc
>>> cache = {'x': 1, 'y': 2} # Compute tasks against a cache
>>> _execute_task((add, 'x', 1), cache) # Compute task in naive manner
2
@@ -77,6 +81,7 @@ def _execute_task(arg, cache):
[[1, 2], [2, 1]]
>>> _execute_task('foo', cache) # Passes through on non-keys
'foo'
+
"""
if isinstance(arg, list):
return [_execute_task(a, cache) for a in arg]
@@ -110,6 +115,7 @@ def get(dsk: dict, out: Sequence[str], cache: dict = None, sortkeys: List[str] =
Examples
--------
+ >>> from flash.core.serve.dag.utils_test import inc
>>> d = {'x': 1, 'y': (inc, 'x')}
>>> get(d, 'x')
1
@@ -117,6 +123,7 @@ def get(dsk: dict, out: Sequence[str], cache: dict = None, sortkeys: List[str] =
2
>>> get(d, 'y', sortkeys=['x', 'y'])
2
+
"""
for k in flatten(out) if isinstance(out, list) else [out]:
if k not in dsk:
@@ -140,6 +147,7 @@ def get_dependencies(dsk, key=None, task=no_default, as_list=False):
Examples
--------
+ >>> from flash.core.serve.dag.utils_test import add, inc
>>> dsk = {'x': 1,
... 'y': (inc, 'x'),
... 'z': (add, 'x', 'y'),
@@ -149,14 +157,15 @@ def get_dependencies(dsk, key=None, task=no_default, as_list=False):
set()
>>> get_dependencies(dsk, 'y')
{'x'}
- >>> get_dependencies(dsk, 'z') # doctest: +SKIP
- {'x', 'y'}
+ >>> sorted(get_dependencies(dsk, 'z'))
+ ['x', 'y']
>>> get_dependencies(dsk, 'w') # Only direct dependencies
{'z'}
>>> get_dependencies(dsk, 'a') # Ignore non-keys
{'x'}
>>> get_dependencies(dsk, task=(inc, 'x')) # provide tasks directly
{'x'}
+
"""
if key is not None:
arg = dsk[key]
@@ -194,12 +203,14 @@ def get_deps(dsk):
Examples
--------
+ >>> from flash.core.serve.dag.utils_test import inc
>>> dsk = {'a': 1, 'b': (inc, 'a'), 'c': (inc, 'b')}
>>> dependencies, dependents = get_deps(dsk)
>>> dependencies
{'a': set(), 'b': {'a'}, 'c': {'b'}}
>>> dict(dependents)
{'a': {'b'}, 'b': {'c'}, 'c': set()}
+
"""
dependencies = {k: get_dependencies(dsk, task=v) for k, v in dsk.items()}
dependents = reverse_dict(dependencies)
@@ -233,8 +244,10 @@ def reverse_dict(d):
"""
>>> a, b, c = 'abc'
>>> d = {a: [b, c], b: [c]}
- >>> reverse_dict(d) # doctest: +SKIP
- {'a': set([]), 'b': set(['a']}, 'c': set(['a', 'b'])}
+ >>> dd = reverse_dict(d)
+ >>> from pprint import pprint
+ >>> pprint({k: sorted(v) for k, v in dd.items()})
+ {'a': [], 'b': ['a'], 'c': ['a', 'b']}
"""
result = defaultdict(set)
_add = set.add
@@ -251,8 +264,10 @@ def subs(task, key, val):
Examples
--------
- >>> subs((inc, 'x'), 'x', 1) # doctest: +SKIP
- (inc, 1)
+ >>> from flash.core.serve.dag.utils_test import inc
+ >>> subs((inc, 'x'), 'x', 1) # doctest: +ELLIPSIS
+ (<function inc at ...>, 1)
+
"""
type_task = type(task)
if not (type_task is tuple and task and callable(task[0])): # istask(task):
@@ -291,6 +306,7 @@ def _toposort(dsk, keys=None, returncycle=False, dependencies=None):
"""Stack-based depth-first search traversal.
This is based on Tarjan's method for topological sorting (see wikipedia for pseudocode).
+
"""
if keys is None:
keys = dsk
@@ -368,6 +384,7 @@ def getcycle(d, keys):
Examples
--------
+ >>> from flash.core.serve.dag.utils_test import inc
>>> d = {'x': (inc, 'z'), 'y': (inc, 'x'), 'z': (inc, 'y')}
>>> getcycle(d, 'x')
['x', 'z', 'y', 'x']
@@ -384,6 +401,7 @@ def isdag(d, keys):
Examples
--------
+ >>> from flash.core.serve.dag.utils_test import inc
>>> isdag({'x': 0, 'y': (inc, 'x')}, 'y')
True
>>> isdag({'x': (inc, 'y'), 'y': (inc, 'x')}, 'y')
@@ -420,8 +438,10 @@ def quote(x):
Examples
--------
- >>> quote((add, 1, 2)) # doctest: +SKIP
+ >>> from flash.core.serve.dag.utils_test import add
+ >>> quote((add, 1, 2))
(literal<type=tuple>,)
+
"""
if istask(x) or type(x) is list or type(x) is dict:
return (literal(x),)
diff --git a/src/flash/core/serve/dag/utils.py b/src/flash/core/serve/dag/utils.py
index e90699cbae..5c5f188fdb 100644
--- a/src/flash/core/serve/dag/utils.py
+++ b/src/flash/core/serve/dag/utils.py
@@ -1,6 +1,4 @@
-"""
-NOTICE: Some methods in this file have been modified from their original source.
-"""
+"""NOTICE: Some methods in this file have been modified from their original source."""
import functools
import re
diff --git a/src/flash/core/serve/decorators.py b/src/flash/core/serve/decorators.py
index 0675d037ee..858b2139c5 100644
--- a/src/flash/core/serve/decorators.py
+++ b/src/flash/core/serve/decorators.py
@@ -107,6 +107,7 @@ def _validate_expose_inputs_outputs_args(kwargs: Dict[str, BaseType]):
>>> out = {'out': Number()}
>>> _validate_expose_inputs_outputs_args(inp)
>>> _validate_expose_inputs_outputs_args(out)
+
"""
if not isinstance(kwargs, dict):
raise TypeError(f"`expose` values must be {dict}. recieved {kwargs}")
@@ -152,6 +153,7 @@ def expose(inputs: Dict[str, BaseType], outputs: Dict[str, BaseType]):
TODO
----
* Examples in the docstring.
+
"""
_validate_expose_inputs_outputs_args(inputs)
_validate_expose_inputs_outputs_args(outputs)
diff --git a/src/flash/core/serve/execution.py b/src/flash/core/serve/execution.py
index 3b555660e0..330ac8242a 100644
--- a/src/flash/core/serve/execution.py
+++ b/src/flash/core/serve/execution.py
@@ -67,6 +67,7 @@ class TaskComposition:
pre_optimization_dsk
Merged component `_dsk` subgraphs (without payload / result
mapping or connections applied.)
+
"""
__slots__ = (
@@ -112,6 +113,7 @@ class UnprocessedTaskDask:
map of output (results) key to output task key
output_keys
keys to get as results
+
"""
__slots__ = (
@@ -150,6 +152,7 @@ def _process_initial(
Returns
-------
UnprocessedTaskDask
+
"""
# mapping payload input keys -> serialized keys / tasks
@@ -256,6 +259,7 @@ def build_composition(
``C_2_1 deserialize`` from ``C_2`` / ``C_1``, we see here that since
endpoints define the path through the DAG, we cannot eliminate them
entirely either.
+
"""
initial_task_dsk = _process_initial(endpoint_protocol, components)
diff --git a/src/flash/core/serve/flash_components.py b/src/flash/core/serve/flash_components.py
index 40485eca90..acd5602f64 100644
--- a/src/flash/core/serve/flash_components.py
+++ b/src/flash/core/serve/flash_components.py
@@ -99,7 +99,6 @@ def predict(self, inputs):
inputs = self.model.transfer_batch_to_device(inputs, self.device)
inputs = self.on_after_batch_transfer(inputs, 0)
preds = self.model.predict_step(inputs, 0)
- preds = self.output_transform(preds)
- return preds
+ return self.output_transform(preds)
return FlashServeModelComponent(model)
diff --git a/src/flash/core/serve/interfaces/models.py b/src/flash/core/serve/interfaces/models.py
index 4d6c84b5b7..2177640b14 100644
--- a/src/flash/core/serve/interfaces/models.py
+++ b/src/flash/core/serve/interfaces/models.py
@@ -37,6 +37,7 @@ class EndpointProtocol:
class initializer. Component inputs & outputs (as defined in `@expose` object decorations) dtype method (`serialize`
and `deserialize`) type hints are inspected in order to construct a specification unique to the endpoint, they are
returned as subclasses of pydantic ``BaseModel``.
+
"""
def __init__(self, name: str, endpoint: "Endpoint", components: Dict[str, "ModelComponent"]):
diff --git a/src/flash/core/serve/server.py b/src/flash/core/serve/server.py
index aeaf00c034..ed1af05222 100644
--- a/src/flash/core/serve/server.py
+++ b/src/flash/core/serve/server.py
@@ -20,6 +20,7 @@ class ServerMixin:
debug If the server should be started up in debug mode. By default, False. testing If the server should
return the ``app`` instance instead of blocking the process (via running the ``app`` in ``uvicorn``). This is
used when taking advantage of a server ``TestClient``. By default, False
+
"""
DEBUG: bool
@@ -37,6 +38,7 @@ def serve(self, host: str = "127.0.0.1", port: int = 8000):
host address to run the server on
port
port number to expose the running server on
+
"""
if FLASH_DISABLE_SERVE:
return None
diff --git a/src/flash/core/serve/types/base.py b/src/flash/core/serve/types/base.py
index 6ef42a8a2f..d530a9c42d 100644
--- a/src/flash/core/serve/types/base.py
+++ b/src/flash/core/serve/types/base.py
@@ -30,6 +30,7 @@ def deserialize(self, text: str, language: str):
.. code-block:: python
{"text": "some string", "language": "en"}
+
"""
@cached_property
@@ -54,6 +55,7 @@ def deserialize(self, *args, **kwargs): # pragma: no cover
"""Take the inputs from the network and deserialize/convert them.
Output from this method will go to the exposed method as arguments.
+
"""
raise NotImplementedError
@@ -64,5 +66,6 @@ def packed_deserialize(self, kwargs):
sophisticated datatypes (such as Repeated) where the developer wants to dictate how the unpacking happens. For
simple cases like Image or Bbox etc., developer would never need to know the existence of this. Task graph would
never call deserialize directly but always call this method.
+
"""
return self.deserialize(**kwargs)
diff --git a/src/flash/core/serve/types/bbox.py b/src/flash/core/serve/types/bbox.py
index ba9a98184d..e85d77d8f1 100644
--- a/src/flash/core/serve/types/bbox.py
+++ b/src/flash/core/serve/types/bbox.py
@@ -17,6 +17,7 @@ class BBox(BaseType):
like Javascript to use a dictionary with ``x1, y1, x2 and y2`` as keys, we went
with DL convention which is to use a list/tuple in which four floats are
arranged in the same ``order -> x1, y1, x2, y2``
+
"""
def __post_init__(self):
diff --git a/src/flash/core/serve/types/image.py b/src/flash/core/serve/types/image.py
index e20b94d884..8641916fda 100644
--- a/src/flash/core/serve/types/image.py
+++ b/src/flash/core/serve/types/image.py
@@ -40,6 +40,7 @@ class Image(BaseType):
"I": 1, # (32-bit signed integer pixels)
"F": 1, # (32-bit floating point pixels)
}
+
"""
height: Optional[int] = None
diff --git a/src/flash/core/serve/types/label.py b/src/flash/core/serve/types/label.py
index a5ad295016..cb1da78a2a 100644
--- a/src/flash/core/serve/types/label.py
+++ b/src/flash/core/serve/types/label.py
@@ -21,6 +21,7 @@ class Label(BaseType):
classes
A list, tuple or a dict of classes. If it's list or a tuple, index of the
class, is the key. If it's a dictionary, the key must be an integer
+
"""
path: Union[str, Path, None] = field(default=None)
diff --git a/src/flash/core/serve/types/repeated.py b/src/flash/core/serve/types/repeated.py
index 5efa86902b..63498d87fa 100644
--- a/src/flash/core/serve/types/repeated.py
+++ b/src/flash/core/serve/types/repeated.py
@@ -18,6 +18,7 @@ class Repeated(BaseType):
Optional parameter specifying if there is a maximum length of the
repeated elements (`int > 0`). If `max_len=None`, there can be any
number of repeated elements. By default: `None`.
+
"""
dtype: BaseType
diff --git a/src/flash/core/serve/types/table.py b/src/flash/core/serve/types/table.py
index 7fe1fb7a33..a073af02af 100644
--- a/src/flash/core/serve/types/table.py
+++ b/src/flash/core/serve/types/table.py
@@ -53,6 +53,7 @@ class Table(BaseType):
* It might be better to remove pandas dependency to gain performance however we
are offloading the validation logic to pandas which would have been painful if
we were to do custom built logic
+
"""
column_names: List[str]
diff --git a/src/flash/core/serve/types/text.py b/src/flash/core/serve/types/text.py
index dfeda9a59d..62586b1c96 100644
--- a/src/flash/core/serve/types/text.py
+++ b/src/flash/core/serve/types/text.py
@@ -22,6 +22,7 @@ class Text(BaseType):
TODO: Allow other arguments such as language, max_len etc. Add guidelines
to write custom tokenizer
+
"""
tokenizer: Union[str, Any]
diff --git a/src/flash/core/serve/utils.py b/src/flash/core/serve/utils.py
index 472493e47c..67585d6105 100644
--- a/src/flash/core/serve/utils.py
+++ b/src/flash/core/serve/utils.py
@@ -9,6 +9,7 @@ def fn_outputs_to_keyed_map(serialize_fn_out_keys, fn_output) -> Dict[str, Any]:
"""Convert outputs of a function to a dict of `{result_name: values}`
accepts function outputs which are sequence, dict, or object.
+
"""
if len(serialize_fn_out_keys) == 1:
if not isinstance(fn_output, dict):
@@ -33,6 +34,7 @@ def download_file(url: str, *, download_path: Optional[Path] = None) -> str:
----
* cleanup on error
* allow specific file names
+
"""
fname = f"{url.split('/')[-1]}"
fpath = str(download_path.absolute()) if download_path is not None else f"./{fname}"
diff --git a/src/flash/core/trainer.py b/src/flash/core/trainer.py
index fa40cb7f4f..fe7e39b83a 100644
--- a/src/flash/core/trainer.py
+++ b/src/flash/core/trainer.py
@@ -74,6 +74,13 @@ def insert_env_defaults(self, *args, **kwargs):
class Trainer(PlTrainer):
+ """Exteded Trainer for FLash tasks.
+
+ >>> Trainer() # doctest: +ELLIPSIS
+ <...trainer.Trainer object at ...>
+
+ """
+
@_defaults_from_env_vars
def __init__(self, *args, **kwargs):
if flash._IS_TESTING:
@@ -178,6 +185,7 @@ def predict(
Returns:
Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
+
"""
# Note: Prediction on TPU device with multi cores is not supported yet
if isinstance(self.accelerator, TPUAccelerator) and self.num_devices > 1:
@@ -256,5 +264,6 @@ def configure_optimizers(self):
optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
)
return [optimizer], [scheduler]
+
"""
return super().estimated_stepping_batches
diff --git a/src/flash/core/utilities/flash_cli.py b/src/flash/core/utilities/flash_cli.py
index 1de8f5f9df..132fc85479 100644
--- a/src/flash/core/utilities/flash_cli.py
+++ b/src/flash/core/utilities/flash_cli.py
@@ -20,8 +20,7 @@
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
import pytorch_lightning as pl
-from jsonargparse import ArgumentParser
-from jsonargparse.signatures import get_class_signature_functions
+from jsonargparse import ArgumentParser, class_from_function
from lightning_utilities.core.overrides import is_overridden
from pytorch_lightning import LightningModule, Trainer
@@ -31,7 +30,6 @@
LightningArgumentParser,
LightningCLI,
SaveConfigCallback,
- class_from_function,
)
from flash.core.utilities.stability import beta
@@ -107,6 +105,16 @@ def wrapper(*args, **kwargs):
return wrapper
+def get_class_signature_functions(classes):
+ signatures = []
+ for num, cls in enumerate(classes):
+ if cls.__new__ is not object.__new__ and not any(cls.__new__ is c.__new__ for c in classes[num + 1 :]):
+ signatures.append((cls, cls.__new__))
+ if not any(cls.__init__ is c.__init__ for c in classes[num + 1 :]):
+ signatures.append((cls, cls.__init__))
+ return signatures
+
+
def get_overlapping_args(func_a, func_b) -> Set[str]:
func_a = get_class_signature_functions([func_a])[0][1]
func_b = get_class_signature_functions([func_b])[0][1]
@@ -214,7 +222,7 @@ def add_arguments_to_parser(self, parser) -> None:
def add_subcommand_from_function(self, subcommands, function, function_name=None):
subcommand = ArgumentParser()
if get_kwarg_name(function) == "data_module_kwargs":
- datamodule_function = class_from_function(function, return_type=self.local_datamodule_class)
+ datamodule_function = class_from_function(function, self.local_datamodule_class)
subcommand.add_class_arguments(
datamodule_function,
fail_untyped=False,
@@ -233,7 +241,7 @@ def add_subcommand_from_function(self, subcommands, function, function_name=None
},
)
else:
- datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class)
+ datamodule_function = class_from_function(drop_kwargs(function), self.local_datamodule_class)
subcommand.add_class_arguments(datamodule_function, fail_untyped=False)
subcommand_name = function_name or function.__name__
subcommands.add_subcommand(subcommand_name, subcommand)
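
Newer jsonargparse releases no longer expose `get_class_signature_functions`, which is why the diff above inlines an equivalent helper. A minimal sketch of how `get_overlapping_args` can use it to find constructor parameters shared by two classes (the classes `A` and `B` below are hypothetical)::

    import inspect

    def get_class_signature_functions(classes):
        # same logic as the helper added above: pick the __new__/__init__ that
        # actually defines each class's constructor signature
        signatures = []
        for num, cls in enumerate(classes):
            if cls.__new__ is not object.__new__ and not any(cls.__new__ is c.__new__ for c in classes[num + 1 :]):
                signatures.append((cls, cls.__new__))
            if not any(cls.__init__ is c.__init__ for c in classes[num + 1 :]):
                signatures.append((cls, cls.__init__))
        return signatures

    class A:
        def __init__(self, x, y=1): ...

    class B:
        def __init__(self, y=2, z=3): ...

    params_a = set(inspect.signature(get_class_signature_functions([A])[0][1]).parameters)
    params_b = set(inspect.signature(get_class_signature_functions([B])[0][1]).parameters)
    print((params_a & params_b) - {"self"})  # {'y'}
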
diff --git a/src/flash/core/utilities/imports.py b/src/flash/core/utilities/imports.py
index 7c3bb75ef8..cb118ddf86 100644
--- a/src/flash/core/utilities/imports.py
+++ b/src/flash/core/utilities/imports.py
@@ -101,13 +101,7 @@ class Image:
_TOPIC_TABULAR_AVAILABLE = all([_PANDAS_AVAILABLE, _FORECASTING_AVAILABLE, _PYTORCHTABULAR_AVAILABLE])
_TOPIC_VIDEO_AVAILABLE = all([_TORCHVISION_AVAILABLE, _PIL_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, _KORNIA_AVAILABLE])
_TOPIC_IMAGE_AVAILABLE = all(
- [
- _TORCHVISION_AVAILABLE,
- _TIMM_AVAILABLE,
- _PIL_AVAILABLE,
- _ALBUMENTATIONS_AVAILABLE,
- _PYSTICHE_AVAILABLE,
- ]
+ [_TORCHVISION_AVAILABLE, _TIMM_AVAILABLE, _PIL_AVAILABLE, _ALBUMENTATIONS_AVAILABLE, _PYSTICHE_AVAILABLE]
)
_TOPIC_SERVE_AVAILABLE = all([_FASTAPI_AVAILABLE, _PYDANTIC_AVAILABLE, _CYTOOLZ_AVAILABLE, _UVICORN_AVAILABLE])
_TOPIC_POINTCLOUD_AVAILABLE = all([_OPEN3D_AVAILABLE, _TORCHVISION_AVAILABLE])
@@ -117,7 +111,7 @@ class Image:
_TOPIC_GRAPH_AVAILABLE = all(
[_TORCH_SCATTER_AVAILABLE, _TORCH_SPARSE_AVAILABLE, _TORCH_GEOMETRIC_AVAILABLE, _NETWORKX_AVAILABLE]
)
-_TOPIC_CORE_AVAILABLE = _TOPIC_IMAGE_AVAILABLE and _TOPIC_TABULAR_AVAILABLE and _TOPIC_TEXT_AVAILABLE
+_TOPIC_CORE_AVAILABLE = all([_TOPIC_IMAGE_AVAILABLE, _TOPIC_TABULAR_AVAILABLE, _TOPIC_TEXT_AVAILABLE])
_EXTRAS_AVAILABLE = {
"image": _TOPIC_IMAGE_AVAILABLE,
@@ -192,6 +186,7 @@ def lazy_import(module_name, callback=None):
Returns:
a proxy module object that will be lazily imported when first used
+
"""
return LazyModule(module_name, callback=callback)
@@ -203,6 +198,7 @@ class LazyModule(types.ModuleType):
module_name: the fully-qualified module name to import
callback (None): a callback function to call before importing the
module
+
"""
def __init__(self, module_name, callback=None):
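
A usage note on the `lazy_import` API documented above, as a minimal sketch (assumes `flash` is importable; the callback is only illustrative)::

    from flash.core.utilities.imports import lazy_import

    def _announce():
        print("importing json now")

    json = lazy_import("json", callback=_announce)  # returns a LazyModule proxy; nothing is imported yet
    print(json.dumps({"ok": True}))  # first attribute access runs the callback and the real import
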
diff --git a/src/flash/core/utilities/lightning_cli.py b/src/flash/core/utilities/lightning_cli.py
index 37ce4a470e..898df68854 100644
--- a/src/flash/core/utilities/lightning_cli.py
+++ b/src/flash/core/utilities/lightning_cli.py
@@ -4,14 +4,11 @@
import os
import warnings
from argparse import Namespace
-from functools import wraps
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
import torch
-from jsonargparse import ActionConfigFile, ArgumentParser, set_config_read_mode
-from jsonargparse.signatures import ClassFromFunctionBase
-from jsonargparse.typehints import ClassType
+from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, set_config_read_mode
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -25,46 +22,17 @@
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
-def class_from_function(
- func: Callable[..., ClassType],
- return_type: Optional[Type[ClassType]] = None,
-) -> Type[ClassType]:
- """Creates a dynamic class which if instantiated is equivalent to calling func.
-
- Args:
- func: A function that returns an instance of a class. It must have a return type annotation.
- """
-
- @wraps(func)
- def __new__(cls, *args, **kwargs):
- return func(*args, **kwargs)
-
- if return_type is None:
- return_type = inspect.signature(func).return_annotation
-
- if isinstance(return_type, str):
- raise RuntimeError("Classmethod instantiation is not supported when the return type annotation is a string.")
-
- class ClassFromFunction(return_type, ClassFromFunctionBase): # type: ignore
- pass
-
- ClassFromFunction.__new__ = __new__ # type: ignore
- ClassFromFunction.__doc__ = func.__doc__
- ClassFromFunction.__name__ = func.__name__
-
- return ClassFromFunction
-
-
class LightningArgumentParser(ArgumentParser):
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""
- def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize argument parser that supports configuration file input.
For full details of accepted arguments see
`ArgumentParser.__init__ `_.
+
"""
- super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs)
+ super().__init__(*args, **kwargs)
self.add_argument(
"--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
)
@@ -89,13 +57,14 @@ def add_lightning_class_args(
lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}.
nested_key: Name of the nested namespace to store arguments.
subclass_mode: Whether allow any subclass of the given class.
+
"""
if callable(lightning_class) and not inspect.isclass(lightning_class):
lightning_class = class_from_function(lightning_class)
if inspect.isclass(lightning_class) and issubclass(
cast(type, lightning_class),
- (Trainer, LightningModule, LightningDataModule, Callback, ClassFromFunctionBase),
+ (Trainer, LightningModule, LightningDataModule, Callback),
):
if issubclass(cast(type, lightning_class), Callback):
self.callback_keys.append(nested_key)
@@ -124,6 +93,7 @@ def add_optimizer_args(
optimizer_class: Any subclass of torch.optim.Optimizer.
nested_key: Name of the nested namespace to store arguments.
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
+
"""
if isinstance(optimizer_class, tuple):
assert all(issubclass(o, Optimizer) for o in optimizer_class)
@@ -152,6 +122,7 @@ def add_lr_scheduler_args(
lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``.
nested_key: Name of the nested namespace to store arguments.
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
+
"""
if isinstance(lr_scheduler_class, tuple):
assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class)
@@ -174,6 +145,7 @@ class SaveConfigCallback(Callback):
Raises:
RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run
+
"""
def __init__(
@@ -342,6 +314,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
Args:
parser: The argument parser object to which arguments can be added
+
"""
def link_optimizers_and_lr_schedulers(self) -> None:
@@ -392,6 +365,7 @@ def add_configure_optimizers_method_to_model(self) -> None:
If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC', then a
`configure_optimizers` method is automatically implemented in the model class.
+
"""
def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]:
@@ -486,6 +460,7 @@ def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -
Returns:
The instantiated class object.
+
"""
kwargs = init.get("init_args", {})
if not isinstance(args, tuple):
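
The local `class_from_function` helper removed above is superseded by the one shipped with jsonargparse, which the imports now use directly. A minimal sketch of what it does (the `Point`/`make_point` names are hypothetical; jsonargparse must be installed)::

    from dataclasses import dataclass
    from jsonargparse import class_from_function

    @dataclass
    class Point:
        x: int
        y: int = 0

    def make_point(x: int, y: int = 0) -> Point:  # a return type annotation is required
        return Point(x, y)

    PointFromFactory = class_from_function(make_point)  # dynamic class; instantiating it calls make_point
    p = PointFromFactory(3, y=4)
    print(isinstance(p, Point), p)  # True Point(x=3, y=4)

This is what lets `add_subcommand_from_function` in flash_cli.py hand the result straight to `add_class_arguments`.
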
diff --git a/src/flash/core/utilities/stages.py b/src/flash/core/utilities/stages.py
index 5e6e653580..995bbd95e9 100644
--- a/src/flash/core/utilities/stages.py
+++ b/src/flash/core/utilities/stages.py
@@ -26,6 +26,7 @@ class RunningStage(LightningEnum):
- ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING``
- ``TrainerFn.SERVING`` - ``RunningStage.SERVING``
- ``TrainerFn.TUNING`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}``
+
"""
TRAINING = "train"
diff --git a/src/flash/graph/embedding/model.py b/src/flash/graph/embedding/model.py
index 3ddd53a38a..e214c25cf1 100644
--- a/src/flash/graph/embedding/model.py
+++ b/src/flash/graph/embedding/model.py
@@ -47,8 +47,7 @@ def __init__(self, backbone: nn.Module, pooling_fn: Optional[Union[str, Callable
def forward(self, data) -> Tensor:
x = self.backbone(data.x, data.edge_index)
- x = self.pooling_fn(x, data.batch)
- return x
+ return self.pooling_fn(x, data.batch)
def training_step(self, batch: Any, batch_idx: int) -> Any:
raise NotImplementedError("Training a `GraphEmbedder` is not supported. Use a `GraphClassifier` instead.")
diff --git a/src/flash/image/classification/backbones/resnet.py b/src/flash/image/classification/backbones/resnet.py
index 0a5e0c6edb..e47e92bec0 100644
--- a/src/flash/image/classification/backbones/resnet.py
+++ b/src/flash/image/classification/backbones/resnet.py
@@ -96,9 +96,7 @@ def forward(self, x: Tensor) -> Tensor:
identity = self.downsample(x)
out += identity
- out = self.relu(out)
-
- return out
+ return self.relu(out)
class Bottleneck(nn.Module):
@@ -155,9 +153,7 @@ def forward(self, x: Tensor) -> Tensor:
identity = self.downsample(x)
out += identity
- out = self.relu(out)
-
- return out
+ return self.relu(out)
class ResNet(nn.Module):
@@ -300,9 +296,7 @@ def forward(self, x: Tensor) -> Tensor:
x = self.layer4(x)
x = self.avgpool(x)
- x = torch.flatten(x, 1)
-
- return x
+ return torch.flatten(x, 1)
def _resnet(
diff --git a/src/flash/image/classification/data.py b/src/flash/image/classification/data.py
index a4e54e6d18..3bf13bb91e 100644
--- a/src/flash/image/classification/data.py
+++ b/src/flash/image/classification/data.py
@@ -56,17 +56,9 @@
# Skip doctests if requirements aren't available
__doctest_skip__ = []
if not _TOPIC_IMAGE_AVAILABLE:
- __doctest_skip__ += [
- "ImageClassificationData",
- "ImageClassificationData.from_files",
- "ImageClassificationData.from_folders",
- "ImageClassificationData.from_numpy",
- "ImageClassificationData.from_images",
- "ImageClassificationData.from_tensors",
- "ImageClassificationData.from_data_frame",
- "ImageClassificationData.from_csv",
- "ImageClassificationData.from_fiftyone",
- ]
+ __doctest_skip__ += ["ImageClassificationData", "ImageClassificationData.*"]
+if not _FIFTYONE_AVAILABLE:
+ __doctest_skip__ += ["ImageClassificationData.from_fiftyone"]
class ImageClassificationData(DataModule):
@@ -1067,11 +1059,10 @@ def from_labelstudio(
multi_label: Optional[bool] = False,
**data_module_kwargs: Any,
) -> "ImageClassificationData":
- """Creates a :class:`~flash.core.data.data_module.DataModule` object
- from the given export file and data directory using the
- :class:`~flash.core.data.io.input.Input` of name
- :attr:`~flash.core.data.io.input.InputFormat.FOLDERS`
- from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`.
+ """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given export file and data
+ directory using the :class:`~flash.core.data.io.input.Input` of name
+ :attr:`~flash.core.data.io.input.InputFormat.FOLDERS` from the passed or constructed
+ :class:`~flash.core.data.io.input_transform.InputTransform`.
Args:
export_json: path to label studio export file
@@ -1146,9 +1137,8 @@ def from_datasets(
**data_module_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given datasets using the
- :class:`~flash.core.data.io.input.Input`
- of name :attr:`~flash.core.data.io.input.InputFormat.DATASETS`
- from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`.
+ :class:`~flash.core.data.io.input.Input` of name :attr:`~flash.core.data.io.input.InputFormat.DATASETS` from the
+ passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`.
Args:
train_dataset: Dataset used during training.
@@ -1197,7 +1187,7 @@ class MatplotlibVisualization(BaseVisualization):
block_viz_window: bool = True # parameter to allow user to block visualisation windows
@staticmethod
- @requires("image")
+ @requires("PIL")
def _to_numpy(img: Union[np.ndarray, Tensor, Image.Image]) -> np.ndarray:
out: np.ndarray
if isinstance(img, np.ndarray):
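
The collapsed skip lists above rely on the wildcard patterns understood by the doctest plugin used for these modules (presumably pytest-doctestplus): `"ImageClassificationData.*"` covers every method's doctest, so one pattern replaces the long per-method list. A minimal self-contained sketch of the convention (the `WidgetData` name and the availability flag are hypothetical)::

    _WIDGET_AVAILABLE = False  # assumption for illustration

    __doctest_skip__ = []
    if not _WIDGET_AVAILABLE:
        # "WidgetData.*" matches the doctests of every attribute/method of WidgetData,
        # in addition to the class-level doctest named "WidgetData"
        __doctest_skip__ += ["WidgetData", "WidgetData.*"]
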
diff --git a/src/flash/image/classification/integrations/learn2learn.py b/src/flash/image/classification/integrations/learn2learn.py
index d15d6b39ad..4939135256 100644
--- a/src/flash/image/classification/integrations/learn2learn.py
+++ b/src/flash/image/classification/integrations/learn2learn.py
@@ -11,10 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Note: This file will be deleted once
-https://github.com/learnables/learn2learn/pull/257/files is merged within Learn2Learn.
-"""
+"""Note: This file will be deleted once https://github.com/learnables/learn2learn/pull/257/files is merged within
+Learn2Learn."""
from typing import Any, Callable, Optional
@@ -44,6 +42,7 @@ def __init__(
epoch_length: The expected epoch length. This is required to be divisible by the number of devices.
devices: Number of devices being used.
collate_fn: The collate_fn to be applied on multiple tasks
+
"""
self.tasks = tasks
self.epoch_length = epoch_length
@@ -99,6 +98,7 @@ def __init__(
num_workers: Number of workers to be provided to the DataLoader.
epoch_length: The expected epoch length. This is required to be divisible by (num_workers * world_size).
seed: The seed will be used on __iter__ call and should be the same for all processes.
+
"""
self.taskset = taskset
self.global_rank = global_rank
diff --git a/src/flash/image/detection/data.py b/src/flash/image/detection/data.py
index 9884ab9889..b6c8135675 100644
--- a/src/flash/image/detection/data.py
+++ b/src/flash/image/detection/data.py
@@ -51,9 +51,14 @@
VOCBBoxParser = object
Parser = object
+__doctest_skip__ = []
# Skip doctests if requirements aren't available
if not _TOPIC_IMAGE_AVAILABLE:
- __doctest_skip__ = ["ObjectDetectionData", "ObjectDetectionData.*"]
+ __doctest_skip__ += ["ObjectDetectionData", "ObjectDetectionData.*"]
+
+
+if not _FIFTYONE_AVAILABLE:
+ __doctest_skip__ += ["ObjectDetectionData.from_fiftyone"]
class ObjectDetectionData(DataModule):
diff --git a/src/flash/image/embedding/heads/vissl_heads.py b/src/flash/image/embedding/heads/vissl_heads.py
index 316f609342..4e69f0b258 100644
--- a/src/flash/image/embedding/heads/vissl_heads.py
+++ b/src/flash/image/embedding/heads/vissl_heads.py
@@ -41,6 +41,7 @@ class SimCLRHead(nn.Module):
model_config: Model config AttrDict from VISSL
dims: list of dimensions for creating a projection head
use_bn: use batch-norm after each linear layer or not
+
"""
def __init__(
@@ -84,7 +85,7 @@ def forward(self, x: Tensor) -> Tensor:
return self.clf(x)
-if _VISSL_AVAILABLE:
+if _VISSL_AVAILABLE and "simclr_head" not in MODEL_HEADS_REGISTRY:
SimCLRHead = register_model_head("simclr_head")(SimCLRHead)
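
The added membership check makes the registration idempotent, so importing this module more than once no longer trips a duplicate-registration error in VISSL's registry. A minimal sketch of the pattern, with a plain dict standing in for `MODEL_HEADS_REGISTRY`::

    MODEL_HEADS_REGISTRY = {}  # stand-in for the real VISSL registry

    def register_model_head(name):
        def wrapper(cls):
            if name in MODEL_HEADS_REGISTRY:
                raise KeyError(f"{name} is already registered")
            MODEL_HEADS_REGISTRY[name] = cls
            return cls
        return wrapper

    class SimCLRHead:
        ...

    if "simclr_head" not in MODEL_HEADS_REGISTRY:  # guard keeps repeated imports safe
        SimCLRHead = register_model_head("simclr_head")(SimCLRHead)
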
diff --git a/src/flash/image/embedding/strategies/default.py b/src/flash/image/embedding/strategies/default.py
index 423a939128..f7e61f971d 100644
--- a/src/flash/image/embedding/strategies/default.py
+++ b/src/flash/image/embedding/strategies/default.py
@@ -74,6 +74,7 @@ def default(head: Optional[str] = None, loss_fn: Optional[str] = None, **kwargs)
"""Return `(None, None, [])` as loss function, head and hooks.
Because the default strategy only supports prediction.
+
"""
if head is not None:
warnings.warn(f"default strategy has no heads. So given head({head}) is ignored.")
diff --git a/src/flash/image/embedding/vissl/adapter.py b/src/flash/image/embedding/vissl/adapter.py
index 96f192ab89..ad901f7ed1 100644
--- a/src/flash/image/embedding/vissl/adapter.py
+++ b/src/flash/image/embedding/vissl/adapter.py
@@ -47,8 +47,7 @@ def __init__(self, backbone: nn.Module):
def forward(self, x, *args, **kwargs):
x = self.backbone(x)
- x = x.unsqueeze(0)
- return x
+ return x.unsqueeze(0)
class MockVISSLTask:
diff --git a/src/flash/image/embedding/vissl/transforms/multicrop.py b/src/flash/image/embedding/vissl/transforms/multicrop.py
index cddd57d9f7..357de4e639 100644
--- a/src/flash/image/embedding/vissl/transforms/multicrop.py
+++ b/src/flash/image/embedding/vissl/transforms/multicrop.py
@@ -44,6 +44,7 @@ class StandardMultiCropSSLTransform(InputTransform):
gaussian_blur (bool): Specifies if the transforms' composition has Gaussian Blur
jitter_strength (float): Specify the coefficient for color jitter transform
normalize (Optional): Normalize transform from torchvision with params set according to the dataset
+
"""
total_num_crops: int = 2
diff --git a/src/flash/image/embedding/vissl/transforms/utilities.py b/src/flash/image/embedding/vissl/transforms/utilities.py
index e8658fecf4..915ee084a7 100644
--- a/src/flash/image/embedding/vissl/transforms/utilities.py
+++ b/src/flash/image/embedding/vissl/transforms/utilities.py
@@ -33,6 +33,7 @@ def multicrop_collate_fn(samples):
"""Multi-crop collate function for VISSL integration.
Run custom collate on a single key since VISSL transforms affect only DataKeys.INPUT
+
"""
result = vissl_collate_helper(samples)
@@ -55,6 +56,7 @@ def simclr_collate_fn(samples):
"""Multi-crop collate function for VISSL integration.
Run custom collate on a single key since VISSL transforms affect only DataKeys.INPUT
+
"""
result = vissl_collate_helper(samples)
diff --git a/src/flash/image/face_detection/input_transform.py b/src/flash/image/face_detection/input_transform.py
index 9a889bd414..83d8c09747 100644
--- a/src/flash/image/face_detection/input_transform.py
+++ b/src/flash/image/face_detection/input_transform.py
@@ -33,6 +33,7 @@ def fastface_collate_fn(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence
"""Collate function from fastface.
Organizes individual elements in a batch, calls prepare_batch from fastface and prepares the targets.
+
"""
samples = {key: [sample[key] for sample in samples] for key in samples[0]}
diff --git a/src/flash/image/face_detection/model.py b/src/flash/image/face_detection/model.py
index ff77bf8ab7..50e4b7ae20 100644
--- a/src/flash/image/face_detection/model.py
+++ b/src/flash/image/face_detection/model.py
@@ -106,13 +106,10 @@ def forward(self, x: List[Tensor]) -> Any:
# preds: Tensor(B, N, 5)
# preds: Tensor(N, 6) as x1,y1,x2,y2,score,batch_idx
preds = self.model.logits_to_preds(logits)
- preds = self.model._postprocess(preds)
-
- return preds
+ return self.model._postprocess(preds)
def _prepare_batch(self, batch):
- batch = (((batch * 255) / self.model.normalizer) - self.model.mean) / self.model.std
- return batch
+ return (((batch * 255) / self.model.normalizer) - self.model.mean) / self.model.std
def _compute_metrics(self, logits, targets):
# preds: Tensor(B, N, 5)
diff --git a/src/flash/image/face_detection/output_transform.py b/src/flash/image/face_detection/output_transform.py
index df9d4064ce..ac58ea486e 100644
--- a/src/flash/image/face_detection/output_transform.py
+++ b/src/flash/image/face_detection/output_transform.py
@@ -36,6 +36,4 @@ def per_batch_transform(batch: Any) -> Any:
# preds: list of Tensor(N, 5) as x1, y1, x2, y2, score
preds = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(len(preds))]
- preds = ff.utils.preprocess.adjust_results(preds, scales, paddings)
-
- return preds
+ return ff.utils.preprocess.adjust_results(preds, scales, paddings)
diff --git a/src/flash/image/segmentation/data.py b/src/flash/image/segmentation/data.py
index 363f5a6df4..c9f98b9a46 100644
--- a/src/flash/image/segmentation/data.py
+++ b/src/flash/image/segmentation/data.py
@@ -19,7 +19,7 @@
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
-from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TOPIC_IMAGE_AVAILABLE, lazy_import
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _SEGMENTATION_MODELS_AVAILABLE, lazy_import
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.image.segmentation.input import (
@@ -41,15 +41,10 @@
# Skip doctests if requirements aren't available
__doctest_skip__ = []
-if not _TOPIC_IMAGE_AVAILABLE:
- __doctest_skip__ += [
- "SemanticSegmentationData",
- "SemanticSegmentationData.from_files",
- "SemanticSegmentationData.from_folders",
- "SemanticSegmentationData.from_numpy",
- "SemanticSegmentationData.from_tensors",
- "SemanticSegmentationData.from_fiftyone",
- ]
+if not _SEGMENTATION_MODELS_AVAILABLE:
+ __doctest_skip__ += ["SemanticSegmentationData", "SemanticSegmentationData.*"]
+if not _FIFTYONE_AVAILABLE:
+ __doctest_skip__ += ["SemanticSegmentationData.from_fiftyone"]
class SemanticSegmentationData(DataModule):
diff --git a/src/flash/image/style_transfer/data.py b/src/flash/image/style_transfer/data.py
index 22a9bbda97..8a3bb45dd3 100644
--- a/src/flash/image/style_transfer/data.py
+++ b/src/flash/image/style_transfer/data.py
@@ -18,7 +18,6 @@
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
-from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE
from flash.core.utilities.stability import beta
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
@@ -27,8 +26,9 @@
from flash.image.style_transfer.input_transform import StyleTransferInputTransform
# Skip doctests if requirements aren't available
-if not _TOPIC_IMAGE_AVAILABLE:
- __doctest_skip__ = ["StyleTransferData", "StyleTransferData.*"]
+# if not _TOPIC_IMAGE_AVAILABLE:
+# skipping as there are some connection/download issues
+__doctest_skip__ = ["StyleTransferData", "StyleTransferData.*"]
@beta("Style transfer is currently in Beta.")
diff --git a/src/flash/pointcloud/detection/open3d_ml/app.py b/src/flash/pointcloud/detection/open3d_ml/app.py
index 07f755f1bc..ea7c14db92 100644
--- a/src/flash/pointcloud/detection/open3d_ml/app.py
+++ b/src/flash/pointcloud/detection/open3d_ml/app.py
@@ -41,6 +41,7 @@ def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768
indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4].
width: The width of the visualization window.
height: The height of the visualization window.
+
"""
# Setup the labels
lut = LabelLUT()
diff --git a/src/flash/pointcloud/segmentation/open3d_ml/app.py b/src/flash/pointcloud/segmentation/open3d_ml/app.py
index 7d3eab3f23..dc22623980 100644
--- a/src/flash/pointcloud/segmentation/open3d_ml/app.py
+++ b/src/flash/pointcloud/segmentation/open3d_ml/app.py
@@ -44,6 +44,7 @@ def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768
indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4].
width: The width of the visualization window.
height: The height of the visualization window.
+
"""
# Setup the labels
lut = LabelLUT()
diff --git a/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py b/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py
index 6f7a4fcc53..a55acb196a 100644
--- a/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py
+++ b/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py
@@ -134,6 +134,7 @@ def get_label_to_names(self):
Returns:
A dict where keys are label numbers and
values are the corresponding names.
+
"""
return self.meta["label_to_names"]
diff --git a/src/flash/tabular/classification/model.py b/src/flash/tabular/classification/model.py
index 15f83a6676..4c36eac568 100644
--- a/src/flash/tabular/classification/model.py
+++ b/src/flash/tabular/classification/model.py
@@ -102,33 +102,30 @@ def __init__(
@property
def data_parameters(self) -> Dict[str, Any]:
- """Get the parameters computed from the training data used to create this
- :class:`~flash.tabular.classification.TabularClassifier`. Use these parameters to load data for
- evaluation / prediction.
-
- Examples
- ________
-
- .. doctest::
-
- >>> import flash
- >>> from flash.core.data.utils import download_data
- >>> from flash.tabular import TabularClassificationData, TabularClassifier
- >>> download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data")
- >>> model = TabularClassifier.load_from_checkpoint(
- ... "https://flash-weights.s3.amazonaws.com/0.7.0/tabular_classification_model.pt"
- ... )
- >>> datamodule = TabularClassificationData.from_csv(
- ... predict_file="data/titanic/predict.csv",
- ... parameters=model.data_parameters,
- ... batch_size=8,
- ... )
- >>> trainer = flash.Trainer()
- >>> trainer.predict(
- ... model,
- ... datamodule=datamodule,
- ... output="classes",
- ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+ r"""Get the parameters computed from the training data used to create this
+ :class:`~flash.tabular.classification.TabularClassifier`. Use these parameters to load data for evaluation /
+ prediction.
+
+ Example::
+
+ import flash
+ from flash.core.data.utils import download_data
+ from flash.tabular import TabularClassificationData, TabularClassifier
+ download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data")
+ model = TabularClassifier.load_from_checkpoint(
+ "https://flash-weights.s3.amazonaws.com/0.7.0/tabular_classification_model.pt"
+ )
+ datamodule = TabularClassificationData.from_csv(
+ predict_file="data/titanic/predict.csv",
+ parameters=model.data_parameters,
+ batch_size=8,
+ )
+ trainer = flash.Trainer()
+ trainer.predict(
+ model,
+ datamodule=datamodule,
+ output="classes",
+ )
Predicting...
"""
return self._data_parameters
diff --git a/src/flash/tabular/classification/utils.py b/src/flash/tabular/classification/utils.py
index 9c208e9105..947c74a599 100644
--- a/src/flash/tabular/classification/utils.py
+++ b/src/flash/tabular/classification/utils.py
@@ -80,8 +80,7 @@ def _pre_transform(
) -> DataFrame:
df = _impute(df, num_cols)
df = _normalize(df, num_cols, mean=mean, std=std)
- df = _categorize(df, cat_cols, codes=codes)
- return df
+ return _categorize(df, cat_cols, codes=codes)
def _to_cat_vars_numpy(df: DataFrame, cat_cols: List[str]) -> list:
diff --git a/src/flash/tabular/regression/model.py b/src/flash/tabular/regression/model.py
index cbbac0f6da..9bae8f76b3 100644
--- a/src/flash/tabular/regression/model.py
+++ b/src/flash/tabular/regression/model.py
@@ -98,33 +98,29 @@ def __init__(
@property
def data_parameters(self) -> Dict[str, Any]:
- """Get the parameters computed from the training data used to create this
- :class:`~flash.tabular.regression.TabularRegressor`. Use these parameters to load data for
- evaluation / prediction.
-
- Examples
- ________
-
- .. doctest::
-
- >>> import flash
- >>> from flash.core.data.utils import download_data
- >>> from flash.tabular import TabularRegressionData, TabularRegressor
- >>> download_data("https://pl-flash-data.s3.amazonaws.com/SeoulBikeData.csv", "./data")
- >>> model = TabularRegressor.load_from_checkpoint(
- ... "https://flash-weights.s3.amazonaws.com/0.7.0/tabular_regression_model.pt"
- ... )
- >>> datamodule = TabularRegressionData.from_csv(
- ... predict_file="data/SeoulBikeData.csv",
- ... parameters=model.data_parameters,
- ... batch_size=8,
- ... )
- >>> trainer = flash.Trainer()
- >>> trainer.predict(
- ... model,
- ... datamodule=datamodule,
- ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
- Predicting...
+ r"""Get the parameters computed from the training data used to create this
+ :class:`~flash.tabular.regression.TabularRegressor`. Use these parameters to load data for evaluation /
+ prediction.
+
+ Example::
+
+ import flash
+ from flash.core.data.utils import download_data
+ from flash.tabular import TabularRegressionData, TabularRegressor
+ download_data("https://pl-flash-data.s3.amazonaws.com/SeoulBikeData.csv", "./data")
+ model = TabularRegressor.load_from_checkpoint(
+ "https://flash-weights.s3.amazonaws.com/0.7.0/tabular_regression_model.pt"
+ )
+ datamodule = TabularRegressionData.from_csv(
+ predict_file="data/SeoulBikeData.csv",
+ parameters=model.data_parameters,
+ batch_size=8,
+ )
+ trainer = flash.Trainer()
+ trainer.predict(
+ model,
+ datamodule=datamodule,
+ )
"""
return self._data_parameters
diff --git a/src/flash/template/classification/data.py b/src/flash/template/classification/data.py
index ce0318f758..9b6ebdeba9 100644
--- a/src/flash/template/classification/data.py
+++ b/src/flash/template/classification/data.py
@@ -90,6 +90,7 @@ def predict_load_data(self, data: Bunch) -> Sequence[Dict[str, Any]]:
Returns:
A sequence of samples / sample metadata.
+
"""
return super().load_data(data.data)
diff --git a/src/flash/text/classification/data.py b/src/flash/text/classification/data.py
index a5ef47cdb3..13588a8796 100644
--- a/src/flash/text/classification/data.py
+++ b/src/flash/text/classification/data.py
@@ -786,11 +786,10 @@ def from_labelstudio(
multi_label: Optional[bool] = False,
**data_module_kwargs: Any,
) -> "TextClassificationData":
- """Creates a :class:`~flash.core.data.data_module.DataModule` object
- from the given export file and data directory using the
- :class:`~flash.core.data.io.input.Input` of name
- :attr:`~flash.core.data.io.input.InputFormat.FOLDERS`
- from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`.
+ """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given export file and data
+ directory using the :class:`~flash.core.data.io.input.Input` of name
+ :attr:`~flash.core.data.io.input.InputFormat.FOLDERS` from the passed or constructed
+ :class:`~flash.core.data.io.input_transform.InputTransform`.
Args:
export_json: path to label studio export file
diff --git a/src/flash/text/seq2seq/core/model.py b/src/flash/text/seq2seq/core/model.py
index 8a00672c78..e988258f69 100644
--- a/src/flash/text/seq2seq/core/model.py
+++ b/src/flash/text/seq2seq/core/model.py
@@ -83,6 +83,7 @@ class Seq2SeqTask(Task):
learning_rate: Learning rate to use for training, defaults to `3e-4`
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
+
"""
required_extras: str = "text"
diff --git a/src/flash/text/seq2seq/summarization/data.py b/src/flash/text/seq2seq/summarization/data.py
index ca8c21eb19..07e334ebb9 100644
--- a/src/flash/text/seq2seq/summarization/data.py
+++ b/src/flash/text/seq2seq/summarization/data.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
from typing import Any, Dict, List, Optional, Type
from flash.core.data.data_module import DataModule
@@ -27,9 +28,14 @@
else:
Dataset = object
+__doctest_skip__ = []
# Skip doctests if requirements aren't available
if not _TOPIC_TEXT_AVAILABLE:
- __doctest_skip__ = ["SummarizationData", "SummarizationData.*"]
+ __doctest_skip__ += ["SummarizationData", "SummarizationData.*"]
+
+# some strange out-of-memory crash with PyTorch 1.11 on Windows
+if os.name == "nt":
+ __doctest_skip__ += ["SummarizationData.from_lists", "SummarizationData.from_json"]
class SummarizationData(DataModule):
@@ -120,7 +126,7 @@ def from_csv(
... predict_file="predict_data.csv",
... batch_size=2,
... )
- >>> model = SummarizationTask()
+ >>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
@@ -173,7 +179,7 @@ def from_csv(
... predict_file="predict_data.tsv",
... batch_size=2,
... )
- >>> model = SummarizationTask()
+ >>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
@@ -282,7 +288,7 @@ def from_json(
... batch_size=2,
... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Downloading...
- >>> model = SummarizationTask()
+ >>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
@@ -375,7 +381,7 @@ def from_hf_datasets(
... predict_hf_dataset=predict_data,
... batch_size=2,
... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
- >>> model = SummarizationTask()
+ >>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
@@ -454,7 +460,7 @@ def from_lists(
... predict_data=["A movie review", "A book chapter"],
... batch_size=2,
... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
- >>> model = SummarizationTask()
+ >>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
diff --git a/src/flash/text/seq2seq/translation/data.py b/src/flash/text/seq2seq/translation/data.py
index 12dced0bec..c919d6c909 100644
--- a/src/flash/text/seq2seq/translation/data.py
+++ b/src/flash/text/seq2seq/translation/data.py
@@ -13,6 +13,8 @@
# limitations under the License.
from typing import Any, Dict, List, Optional, Type
+import torch
+
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import InputTransform
@@ -28,7 +30,7 @@
Dataset = object
# Skip doctests if requirements aren't available
-if not _TOPIC_TEXT_AVAILABLE:
+if not _TOPIC_TEXT_AVAILABLE or not torch.cuda.is_available():
__doctest_skip__ = ["TranslationData", "TranslationData.*"]
diff --git a/src/flash/video/classification/data.py b/src/flash/video/classification/data.py
index 68351013e7..e55392b482 100644
--- a/src/flash/video/classification/data.py
+++ b/src/flash/video/classification/data.py
@@ -50,15 +50,9 @@
# Skip doctests if requirements aren't available
__doctest_skip__ = []
if not _TOPIC_VIDEO_AVAILABLE:
- __doctest_skip__ += [
- "VideoClassificationData",
- "VideoClassificationData.from_files",
- "VideoClassificationData.from_folders",
- "VideoClassificationData.from_data_frame",
- "VideoClassificationData.from_csv",
- "VideoClassificationData.from_tensors",
- "VideoClassificationData.from_fiftyone",
- ]
+ __doctest_skip__ += ["VideoClassificationData", "VideoClassificationData.*"]
+if not _FIFTYONE_AVAILABLE:
+ __doctest_skip__ += ["VideoClassificationData.from_fiftyone"]
class VideoClassificationData(DataModule):
@@ -1137,11 +1131,10 @@ def from_labelstudio(
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs,
) -> "VideoClassificationData":
- """Creates a :class:`~flash.core.data.data_module.DataModule` object
- from the given export file and data directory using the
- :class:`~flash.core.data.io.input.Input` of name
- :attr:`~flash.core.data.io.input.InputFormat.FOLDERS`
- from the passed or constructed :class:`~flash.core.data.io.input_transform.InputTransform`.
+ """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given export file and data
+ directory using the :class:`~flash.core.data.io.input.Input` of name
+ :attr:`~flash.core.data.io.input.InputFormat.FOLDERS` from the passed or constructed
+ :class:`~flash.core.data.io.input_transform.InputTransform`.
Args:
export_json: path to label studio export file
diff --git a/src/flash/video/classification/model.py b/src/flash/video/classification/model.py
index 578312cad7..f324474e0d 100644
--- a/src/flash/video/classification/model.py
+++ b/src/flash/video/classification/model.py
@@ -62,6 +62,7 @@ class VideoClassifier(ClassificationTask):
head: either a `nn.Module` or a callable function that converts the features extracted from the backbone
into class log probabilities (assuming default loss function). If `None`, will default to using
a single linear layer.
+
"""
backbones: FlashRegistry = _VIDEO_CLASSIFIER_BACKBONES
diff --git a/src/flash/video/classification/utils.py b/src/flash/video/classification/utils.py
index 1c8cd2526e..90951ee06b 100644
--- a/src/flash/video/classification/utils.py
+++ b/src/flash/video/classification/utils.py
@@ -50,6 +50,7 @@ def __next__(self) -> dict:
'video_label':
'video_index': ,
}
+
"""
if not self._video_sampler_iter:
# Setup MultiProcessSampler here - after PyTorch DataLoader workers are spawned.
diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py
index c1adaad03b..31e27e1f2f 100644
--- a/tests/audio/classification/test_data.py
+++ b/tests/audio/classification/test_data.py
@@ -14,14 +14,13 @@
from pathlib import Path
from typing import Tuple
+import flash
import numpy as np
import pytest
import torch
-from pytorch_lightning import seed_everything
-
-import flash
from flash.audio import AudioClassificationData
from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _TOPIC_AUDIO_AVAILABLE
+from pytorch_lightning import seed_everything
if _PIL_AVAILABLE:
from PIL import Image
diff --git a/tests/audio/classification/test_model.py b/tests/audio/classification/test_model.py
index 5b164dc912..a499efba74 100644
--- a/tests/audio/classification/test_model.py
+++ b/tests/audio/classification/test_model.py
@@ -15,7 +15,6 @@
from unittest.mock import patch
import pytest
-
from flash.__main__ import main
from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE, _TOPIC_IMAGE_AVAILABLE
diff --git a/tests/audio/speech_recognition/test_data.py b/tests/audio/speech_recognition/test_data.py
index 37ffa918f8..7e20e38e3f 100644
--- a/tests/audio/speech_recognition/test_data.py
+++ b/tests/audio/speech_recognition/test_data.py
@@ -15,9 +15,8 @@
import os
from pathlib import Path
-import pytest
-
import flash
+import pytest
from flash.audio import SpeechRecognitionData
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE
diff --git a/tests/audio/speech_recognition/test_data_model_integration.py b/tests/audio/speech_recognition/test_data_model_integration.py
index 6cedcaad59..e2fc574e24 100644
--- a/tests/audio/speech_recognition/test_data_model_integration.py
+++ b/tests/audio/speech_recognition/test_data_model_integration.py
@@ -15,9 +15,8 @@
import os
from pathlib import Path
-import pytest
-
import flash
+import pytest
from flash import Trainer
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE
diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py
index 29cc516bcf..61e71dcf6e 100644
--- a/tests/audio/speech_recognition/test_model.py
+++ b/tests/audio/speech_recognition/test_model.py
@@ -17,11 +17,11 @@
import numpy as np
import pytest
import torch
-from torch import Tensor
-
from flash.audio import SpeechRecognition
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE, _TOPIC_SERVE_AVAILABLE
+from torch import Tensor
+
from tests.helpers.task_tester import TaskTester
TEST_BACKBONE = "patrickvonplaten/wav2vec2_tiny_random_robust" # tiny model for testing
diff --git a/tests/conftest.py b/tests/conftest.py
index edf66a29e9..9aa6bb6cf0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,10 +4,9 @@
import pytest
import torch
-from pytest_mock import MockerFixture
-
from flash.core.serve.decorators import uuid4 # noqa (used in mocker.patch)
from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE, _TORCHVISION_AVAILABLE
+from pytest_mock import MockerFixture
if _TORCHVISION_AVAILABLE:
import torchvision
diff --git a/tests/core/data/io/test_input.py b/tests/core/data/io/test_input.py
index fc668ef64d..c29f6bc9ed 100644
--- a/tests/core/data/io/test_input.py
+++ b/tests/core/data/io/test_input.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-
from flash.core.data.io.input import Input, IterableInput, ServeInput
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
from flash.core.utilities.stages import RunningStage
diff --git a/tests/core/data/io/test_output.py b/tests/core/data/io/test_output.py
index 1f85c7eb42..0273481f37 100644
--- a/tests/core/data/io/test_output.py
+++ b/tests/core/data/io/test_output.py
@@ -14,7 +14,6 @@
from unittest.mock import Mock
import pytest
-
from flash.core.data.io.output import Output
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
diff --git a/tests/core/data/io/test_output_transform.py b/tests/core/data/io/test_output_transform.py
index d84907e674..16ab8c821f 100644
--- a/tests/core/data/io/test_output_transform.py
+++ b/tests/core/data/io/test_output_transform.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
import torch
-
from flash.core.data.io.output_transform import OutputTransform
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
diff --git a/tests/core/data/test_base_viz.py b/tests/core/data/test_base_viz.py
index 954750216f..2fbc96ac4f 100644
--- a/tests/core/data/test_base_viz.py
+++ b/tests/core/data/test_base_viz.py
@@ -17,15 +17,14 @@
import numpy as np
import pytest
import torch
-from pytorch_lightning import seed_everything
-from torch import Tensor
-
from flash.core.data.base_viz import BaseVisualization
from flash.core.data.io.input import DataKeys
from flash.core.data.utils import _CALLBACK_FUNCS
from flash.core.utilities.imports import _PIL_AVAILABLE, _TOPIC_IMAGE_AVAILABLE
from flash.core.utilities.stages import RunningStage
from flash.image import ImageClassificationData
+from pytorch_lightning import seed_everything
+from torch import Tensor
if _PIL_AVAILABLE:
from PIL import Image
diff --git a/tests/core/data/test_batch.py b/tests/core/data/test_batch.py
index f9efad9091..32f12f4568 100644
--- a/tests/core/data/test_batch.py
+++ b/tests/core/data/test_batch.py
@@ -15,7 +15,6 @@
import pytest
import torch
-
from flash.core.data.batch import default_uncollate
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
diff --git a/tests/core/data/test_callback.py b/tests/core/data/test_callback.py
index 6b3d195945..45b44d06a7 100644
--- a/tests/core/data/test_callback.py
+++ b/tests/core/data/test_callback.py
@@ -15,7 +15,6 @@
import pytest
import torch
-
from flash import DataKeys
from flash.core.data.data_module import DataModule, DatasetInput
from flash.core.data.io.input_transform import InputTransform
diff --git a/tests/core/data/test_callbacks.py b/tests/core/data/test_callbacks.py
index f31a3ce82d..41e71210f2 100644
--- a/tests/core/data/test_callbacks.py
+++ b/tests/core/data/test_callbacks.py
@@ -15,14 +15,13 @@
import pytest
import torch
-from torch import tensor
-
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import InputTransform
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
from flash.core.utilities.stages import RunningStage
+from torch import tensor
@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
diff --git a/tests/core/data/test_data_module.py b/tests/core/data/test_data_module.py
index 76f443ac91..e36f1136e5 100644
--- a/tests/core/data/test_data_module.py
+++ b/tests/core/data/test_data_module.py
@@ -18,15 +18,15 @@
import numpy as np
import pytest
import torch
-from pytorch_lightning import seed_everything
-from torch.utils.data import Dataset
-
from flash import Task, Trainer
from flash.core.data.data_module import DataModule, DatasetInput
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import InputTransform
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE, _TOPIC_IMAGE_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.stages import RunningStage
+from pytorch_lightning import seed_everything
+from torch.utils.data import Dataset
+
from tests.helpers.boring_model import BoringModel
if _TORCHVISION_AVAILABLE:
diff --git a/tests/core/data/test_input_transform.py b/tests/core/data/test_input_transform.py
index 7394f62ece..4cdfd75b9a 100644
--- a/tests/core/data/test_input_transform.py
+++ b/tests/core/data/test_input_transform.py
@@ -14,7 +14,6 @@
from typing import Callable
import pytest
-
from flash.core.data.io.input_transform import InputTransform
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
from flash.core.utilities.stages import RunningStage
diff --git a/tests/core/data/test_properties.py b/tests/core/data/test_properties.py
index 82d781d887..c851c7cf14 100644
--- a/tests/core/data/test_properties.py
+++ b/tests/core/data/test_properties.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-
from flash.core.data.properties import Properties
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
from flash.core.utilities.stages import RunningStage
diff --git a/tests/core/data/test_splits.py b/tests/core/data/test_splits.py
index ea670135c3..a3f1ac3da6 100644
--- a/tests/core/data/test_splits.py
+++ b/tests/core/data/test_splits.py
@@ -15,7 +15,6 @@
import numpy as np
import pytest
-
from flash.core.data.data_module import DataModule
from flash.core.data.splits import SplitDataset
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
diff --git a/tests/core/data/test_transforms.py b/tests/core/data/test_transforms.py
index 3667ac74ee..cf83c39354 100644
--- a/tests/core/data/test_transforms.py
+++ b/tests/core/data/test_transforms.py
@@ -15,7 +15,6 @@
import pytest
import torch
-
from flash.core.data.io.input import DataKeys
from flash.core.data.transforms import ApplyToKeys
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
diff --git a/tests/core/data/utilities/test_classification.py b/tests/core/data/utilities/test_classification.py
index 0bfb809e55..86d16f44e2 100644
--- a/tests/core/data/utilities/test_classification.py
+++ b/tests/core/data/utilities/test_classification.py
@@ -17,7 +17,6 @@
import numpy as np
import pytest
import torch
-
from flash.core.data.utilities.classification import (
CommaDelimitedMultiLabelTargetFormatter,
MultiBinaryTargetFormatter,
diff --git a/tests/core/data/utilities/test_loading.py b/tests/core/data/utilities/test_loading.py
index 3ea85e2d98..d5706b2011 100644
--- a/tests/core/data/utilities/test_loading.py
+++ b/tests/core/data/utilities/test_loading.py
@@ -15,7 +15,6 @@
import numpy as np
import pytest
-
from flash.core.data.utilities.loading import (
AUDIO_EXTENSIONS,
CSV_EXTENSIONS,
diff --git a/tests/core/data/utilities/test_paths.py b/tests/core/data/utilities/test_paths.py
index ebfc649b05..14a827c311 100644
--- a/tests/core/data/utilities/test_paths.py
+++ b/tests/core/data/utilities/test_paths.py
@@ -17,10 +17,9 @@
from typing import List
import pytest
-from numpy import random
-
from flash.core.data.utilities.loading import AUDIO_EXTENSIONS, IMG_EXTENSIONS, NP_EXTENSIONS
from flash.core.data.utilities.paths import PATH_TYPE, filter_valid_files
+from numpy import random
def _make_mock_dir(root, mock_files: List) -> List[PATH_TYPE]:
diff --git a/tests/core/integrations/labelstudio/test_labelstudio.py b/tests/core/integrations/labelstudio/test_labelstudio.py
index 84d1a19576..e32460373a 100644
--- a/tests/core/integrations/labelstudio/test_labelstudio.py
+++ b/tests/core/integrations/labelstudio/test_labelstudio.py
@@ -1,5 +1,4 @@
import pytest
-
from flash.core.data.utils import download_data
from flash.core.integrations.labelstudio.input import (
LabelStudioImageClassificationInput,
diff --git a/tests/core/integrations/vissl/test_strategies.py b/tests/core/integrations/vissl/test_strategies.py
index 2a6faf9cfe..4b517dc80d 100644
--- a/tests/core/integrations/vissl/test_strategies.py
+++ b/tests/core/integrations/vissl/test_strategies.py
@@ -12,19 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE
from flash.image.embedding.heads.vissl_heads import SimCLRHead
from flash.image.embedding.vissl.hooks import TrainingSetupHook
if _VISSL_AVAILABLE:
+ from flash.image.embedding.strategies import IMAGE_EMBEDDER_STRATEGIES
from vissl.hooks.swav_hooks import NormalizePrototypesHook, SwAVUpdateQueueScoresHook
from vissl.losses.barlow_twins_loss import BarlowTwinsLoss
from vissl.losses.swav_loss import SwAVLoss
from vissl.models.heads.swav_prototypes_head import SwAVPrototypesHead
-
- from flash.image.embedding.strategies import IMAGE_EMBEDDER_STRATEGIES
else:
NormalizePrototypesHook = object
SwAVUpdateQueueScoresHook = object
diff --git a/tests/core/optimizers/test_lr_scheduler.py b/tests/core/optimizers/test_lr_scheduler.py
index 703fc851aa..317fd0e6be 100644
--- a/tests/core/optimizers/test_lr_scheduler.py
+++ b/tests/core/optimizers/test_lr_scheduler.py
@@ -14,11 +14,10 @@
import math
import pytest
-from torch import nn
-from torch.optim import Adam
-
from flash.core.optimizers import LinearWarmupCosineAnnealingLR
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
+from torch import nn
+from torch.optim import Adam
@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
diff --git a/tests/core/optimizers/test_optimizers.py b/tests/core/optimizers/test_optimizers.py
index 51b82233b2..09708552d9 100644
--- a/tests/core/optimizers/test_optimizers.py
+++ b/tests/core/optimizers/test_optimizers.py
@@ -13,10 +13,9 @@
# limitations under the License.
import pytest
import torch
-from torch import nn
-
from flash.core.optimizers import LAMB, LARS, LinearWarmupCosineAnnealingLR
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
+from torch import nn
@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
diff --git a/tests/core/serve/models.py b/tests/core/serve/models.py
index 8b907e932d..3ddfe0421d 100644
--- a/tests/core/serve/models.py
+++ b/tests/core/serve/models.py
@@ -2,11 +2,10 @@
import pytorch_lightning as pl
import torch
-from torch import Tensor
-
from flash.core.serve import ModelComponent, expose
from flash.core.serve.types import Image, Label, Number, Repeated
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
+from torch import Tensor
if _TORCHVISION_AVAILABLE:
from torchvision.models import squeezenet1_1
diff --git a/tests/core/serve/test_components.py b/tests/core/serve/test_components.py
index c77ea49457..52e5843eea 100644
--- a/tests/core/serve/test_components.py
+++ b/tests/core/serve/test_components.py
@@ -1,8 +1,8 @@
import pytest
import torch
-
from flash.core.serve.types import Label
from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE
+
from tests.core.serve.models import ClassificationInferenceComposable, LightningSqueezenet
diff --git a/tests/core/serve/test_composition.py b/tests/core/serve/test_composition.py
index 1152f2c69e..c50f565b48 100644
--- a/tests/core/serve/test_composition.py
+++ b/tests/core/serve/test_composition.py
@@ -2,7 +2,6 @@
from dataclasses import asdict
import pytest
-
from flash.core.serve import Composition, Endpoint
from flash.core.utilities.imports import _FASTAPI_AVAILABLE, _TOPIC_SERVE_AVAILABLE
diff --git a/tests/core/serve/test_dag/test_optimization.py b/tests/core/serve/test_dag/test_optimization.py
index 12b25d73f1..d67a81dcb6 100644
--- a/tests/core/serve/test_dag/test_optimization.py
+++ b/tests/core/serve/test_dag/test_optimization.py
@@ -3,7 +3,6 @@
from functools import partial
import pytest
-
from flash.core.serve.dag.optimization import (
SubgraphCallable,
cull,
diff --git a/tests/core/serve/test_dag/test_order.py b/tests/core/serve/test_dag/test_order.py
index 3ede55d64b..b22c81a8c5 100644
--- a/tests/core/serve/test_dag/test_order.py
+++ b/tests/core/serve/test_dag/test_order.py
@@ -1,5 +1,4 @@
import pytest
-
from flash.core.serve.dag.order import ndependencies, order
from flash.core.serve.dag.task import get, get_deps
from flash.core.serve.dag.utils_test import add, inc
@@ -48,6 +47,7 @@ def test_avoid_broker_nodes(abcde):
a0 a1
a0 should be run before a1
+
"""
a, b, c, d, e = abcde
dsk = {
@@ -101,6 +101,7 @@ def test_base_of_reduce_preferred(abcde):
c
We really want to run b0 quickly
+
"""
a, b, c, d, e = abcde
dsk = {(a, i): (f, (a, i - 1), (b, i)) for i in [1, 2, 3]}
@@ -202,6 +203,7 @@ def test_deep_bases_win_over_dependents(abcde):
b c |
/ \ | /
e d
+
"""
a, b, c, d, e = abcde
dsk = {a: (f, b, c, d), b: (f, d, e), c: (f, d), d: 1, e: 2}
@@ -298,6 +300,7 @@ def test_run_smaller_sections(abcde):
a c e cc
Prefer to run acb first because then we can get that out of the way
+
"""
a, b, c, d, e = abcde
aa, bb, cc, dd = (x * 2 for x in [a, b, c, d])
@@ -390,6 +393,7 @@ def test_nearest_neighbor(abcde):
Want to finish off a local group before moving on.
This is difficult because all groups are connected.
+
"""
a, b, c, _, _ = abcde
a1, a2, a3, a4, a5, a6, a7, a8, a9 = (a + i for i in "123456789")
@@ -528,6 +532,7 @@ def test_map_overlap(abcde):
e1 e2 e5
Want to finish b1 before we start on e5
+
"""
a, b, c, d, e = abcde
dsk = {
@@ -698,6 +703,7 @@ def test_switching_dependents(abcde):
This test is pretty specific to how `order` is implemented
and is intended to increase code coverage.
+
"""
a, b, c, d, e = abcde
dsk = {
diff --git a/tests/core/serve/test_dag/test_rewrite.py b/tests/core/serve/test_dag/test_rewrite.py
index 1e1bfa8a35..1c321992bf 100644
--- a/tests/core/serve/test_dag/test_rewrite.py
+++ b/tests/core/serve/test_dag/test_rewrite.py
@@ -1,5 +1,4 @@
import pytest
-
from flash.core.serve.dag.rewrite import VAR, RewriteRule, RuleSet, Traverser, args, head
from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE
diff --git a/tests/core/serve/test_dag/test_task.py b/tests/core/serve/test_dag/test_task.py
index 49c69ade50..002b41d884 100644
--- a/tests/core/serve/test_dag/test_task.py
+++ b/tests/core/serve/test_dag/test_task.py
@@ -2,7 +2,6 @@
from collections import namedtuple
import pytest
-
from flash.core.serve.dag.task import (
flatten,
get,
@@ -115,6 +114,7 @@ def test_get_dependencies_task_none():
@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.")
def test_get_deps():
"""
+ >>> from flash.core.serve.dag.utils_test import inc
>>> dsk = {'a': 1, 'b': (inc, 'a'), 'c': (inc, 'b')}
>>> dependencies, dependents = get_deps(dsk)
>>> dependencies
diff --git a/tests/core/serve/test_dag/test_utils.py b/tests/core/serve/test_dag/test_utils.py
index f3faf47606..ded5f170d7 100644
--- a/tests/core/serve/test_dag/test_utils.py
+++ b/tests/core/serve/test_dag/test_utils.py
@@ -3,7 +3,6 @@
import numpy as np
import pytest
-
from flash.core.serve.dag.utils import funcname, partial_by_order
from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _TOPIC_SERVE_AVAILABLE
diff --git a/tests/core/serve/test_gridbase_validations.py b/tests/core/serve/test_gridbase_validations.py
index c343d95622..6ae23f7a80 100644
--- a/tests/core/serve/test_gridbase_validations.py
+++ b/tests/core/serve/test_gridbase_validations.py
@@ -1,5 +1,4 @@
import pytest
-
from flash.core.serve import ModelComponent, expose
from flash.core.serve.types import Number
from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _TOPIC_SERVE_AVAILABLE
@@ -176,6 +175,7 @@ def test_ModelComponent_raises_if_exposed_input_keys_differ_from_decorated_metho
This is noted because it differs from some other metaclass validations which will raise an exception at class
definition time.
+
"""
from tests.core.serve.models import ClassificationInference
@@ -199,6 +199,7 @@ def test_ModelComponent_raises_if_config_is_empty_dict(lightning_squeezenet1_1_o
This is noted because it differs from some other metaclass validations which will raise an exception at class
definition time.
+
"""
class ConfigComponent(ModelComponent):
@@ -219,6 +220,7 @@ def test_ModelComponent_raises_if_model_is_empty_iterable():
This is noted because it differs from some other metaclass validations which will raise an exception at class
definition time.
+
"""
class ConfigComponent(ModelComponent):
diff --git a/tests/core/serve/test_integration.py b/tests/core/serve/test_integration.py
index d54d2dab60..b96597d4e3 100644
--- a/tests/core/serve/test_integration.py
+++ b/tests/core/serve/test_integration.py
@@ -1,7 +1,6 @@
import base64
import pytest
-
from flash.core.serve import Composition, Endpoint
from flash.core.utilities.imports import _FASTAPI_AVAILABLE, _TOPIC_SERVE_AVAILABLE
diff --git a/tests/core/serve/test_types/test_bbox.py b/tests/core/serve/test_types/test_bbox.py
index 3fe9e273b6..e0b5fe7c77 100644
--- a/tests/core/serve/test_types/test_bbox.py
+++ b/tests/core/serve/test_types/test_bbox.py
@@ -1,6 +1,5 @@
import pytest
import torch
-
from flash.core.serve.types import BBox
from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE
diff --git a/tests/core/serve/test_types/test_image.py b/tests/core/serve/test_types/test_image.py
index d96dc7671a..9689c4b9c0 100644
--- a/tests/core/serve/test_types/test_image.py
+++ b/tests/core/serve/test_types/test_image.py
@@ -2,10 +2,9 @@
import numpy as np
import pytest
-from torch import Tensor
-
from flash.core.serve.types import Image
from flash.core.utilities.imports import _PIL_AVAILABLE, _TOPIC_SERVE_AVAILABLE
+from torch import Tensor
@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.")
diff --git a/tests/core/serve/test_types/test_label.py b/tests/core/serve/test_types/test_label.py
index 72c037e6fb..5ef8794d60 100644
--- a/tests/core/serve/test_types/test_label.py
+++ b/tests/core/serve/test_types/test_label.py
@@ -1,6 +1,5 @@
import pytest
import torch
-
from flash.core.serve.types import Label
from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE
diff --git a/tests/core/serve/test_types/test_number.py b/tests/core/serve/test_types/test_number.py
index deefd8477d..b640fb9663 100644
--- a/tests/core/serve/test_types/test_number.py
+++ b/tests/core/serve/test_types/test_number.py
@@ -1,6 +1,5 @@
import pytest
import torch
-
from flash.core.serve.types import Number
from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE
diff --git a/tests/core/serve/test_types/test_repeated.py b/tests/core/serve/test_types/test_repeated.py
index a7d7b035b8..a23913bf89 100644
--- a/tests/core/serve/test_types/test_repeated.py
+++ b/tests/core/serve/test_types/test_repeated.py
@@ -1,6 +1,5 @@
import pytest
import torch
-
from flash.core.serve.types import Label, Repeated
from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE
diff --git a/tests/core/serve/test_types/test_table.py b/tests/core/serve/test_types/test_table.py
index 78eb96d1bf..a27c55a279 100644
--- a/tests/core/serve/test_types/test_table.py
+++ b/tests/core/serve/test_types/test_table.py
@@ -1,6 +1,5 @@
import pytest
import torch
-
from flash.core.serve.types import Table
from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE
diff --git a/tests/core/serve/test_types/test_text.py b/tests/core/serve/test_types/test_text.py
index 3a813b05db..0c0aebb3f7 100644
--- a/tests/core/serve/test_types/test_text.py
+++ b/tests/core/serve/test_types/test_text.py
@@ -2,7 +2,6 @@
import pytest
import torch
-
from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE
diff --git a/tests/core/test_classification.py b/tests/core/test_classification.py
index dff8750cd2..f27310fc89 100644
--- a/tests/core/test_classification.py
+++ b/tests/core/test_classification.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
import torch
-
from flash.core.classification import (
ClassesOutput,
FiftyOneLabelsOutput,
diff --git a/tests/core/test_data.py b/tests/core/test_data.py
index de7513713d..a9bff42137 100644
--- a/tests/core/test_data.py
+++ b/tests/core/test_data.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
import torch
-
from flash import DataKeys, DataModule, RunningStage
from flash.core.data.data_module import DatasetInput
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py
index e49a1817f5..4d7a093936 100644
--- a/tests/core/test_finetuning.py
+++ b/tests/core/test_finetuning.py
@@ -14,19 +14,19 @@
from numbers import Number
from typing import Iterable, List, Optional, Tuple, Union
+import flash
import pytest
import pytorch_lightning as pl
import torch
+from flash.core.finetuning import _FINETUNING_STRATEGIES_REGISTRY
+from flash.core.model import Task
+from flash.core.utilities.imports import _DEEPSPEED_AVAILABLE, _TOPIC_CORE_AVAILABLE
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch import Tensor
from torch.nn import Flatten, Linear, LogSoftmax, Module
from torch.nn import functional as F
from torch.utils.data import DataLoader
-import flash
-from flash.core.finetuning import _FINETUNING_STRATEGIES_REGISTRY
-from flash.core.model import Task
-from flash.core.utilities.imports import _DEEPSPEED_AVAILABLE, _TOPIC_CORE_AVAILABLE
from tests.helpers.boring_model import BoringModel
diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index 909f126d60..3d9ced8ae6 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -19,17 +19,10 @@
from typing import Any, Tuple
from unittest.mock import ANY, MagicMock
+import flash
import pytest
import pytorch_lightning as pl
import torch
-from pytorch_lightning import LightningDataModule
-from pytorch_lightning.callbacks import Callback
-from torch import Tensor, nn
-from torch.nn import functional as F
-from torch.utils.data import DataLoader
-from torchmetrics import Accuracy
-
-import flash
from flash import Task
from flash.audio import SpeechRecognition
from flash.core.adapter import Adapter
@@ -54,6 +47,12 @@
from flash.image import ImageClassifier, SemanticSegmentation
from flash.tabular import TabularClassifier
from flash.text import SummarizationTask, TextClassifier, TranslationTask
+from pytorch_lightning import LightningDataModule
+from pytorch_lightning.callbacks import Callback
+from torch import Tensor, nn
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from torchmetrics import Accuracy
# ======== Mock functions ========
diff --git a/tests/core/test_registry.py b/tests/core/test_registry.py
index 031b2dbfb9..457b3ab6cd 100644
--- a/tests/core/test_registry.py
+++ b/tests/core/test_registry.py
@@ -14,10 +14,9 @@
import logging
import pytest
-from torch import nn
-
from flash.core.registry import ConcatRegistry, ExternalRegistry, FlashRegistry
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
+from torch import nn
@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
diff --git a/tests/core/test_trainer.py b/tests/core/test_trainer.py
index 695185e372..aa6668dddd 100644
--- a/tests/core/test_trainer.py
+++ b/tests/core/test_trainer.py
@@ -16,6 +16,9 @@
import pytest
import torch
+from flash import Trainer
+from flash.core.classification import ClassificationTask
+from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
from pytorch_lightning.core.lightning import LightningModule
@@ -24,10 +27,6 @@
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
-from flash import Trainer
-from flash.core.classification import ClassificationTask
-from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
-
class DummyDataset(torch.utils.data.Dataset):
def __init__(self, predict: bool = False):
diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py
index 8bc2abc8d1..3d2f707c20 100644
--- a/tests/core/test_utils.py
+++ b/tests/core/test_utils.py
@@ -14,7 +14,6 @@
import os
import pytest
-
from flash.core.data.utils import download_data
from flash.core.utilities.apply_func import get_callable_dict, get_callable_name
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
diff --git a/tests/core/utilities/test_embedder.py b/tests/core/utilities/test_embedder.py
index 5f4c297d94..20f6792b29 100644
--- a/tests/core/utilities/test_embedder.py
+++ b/tests/core/utilities/test_embedder.py
@@ -15,11 +15,10 @@
import pytest
import torch
-from pytorch_lightning import LightningModule
-from torch import nn
-
from flash.core.utilities.embedder import Embedder
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
+from pytorch_lightning import LightningModule
+from torch import nn
class EmbedderTestModel(LightningModule):
@@ -55,6 +54,7 @@ def test_embedder(layer, size):
assert embedder(torch.rand(10, 10)).size(1) == size
+@pytest.mark.flaky(reruns=3)
@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
def test_embedder_scaling_overhead():
"""Tests that embedding to the 3rd layer of a 200 layer model takes less than double the time of embedding to.
@@ -63,6 +63,7 @@ def test_embedder_scaling_overhead():
200 layer model.
Note that this bound is intentionally high in an effort to reduce the flakiness of the test.
+
"""
shallow_embedder = Embedder(NLayerModel(3), "backbone.2")
@@ -84,12 +85,14 @@ def test_embedder_scaling_overhead():
assert (diff_time / shallow_time) < 2
+@pytest.mark.flaky(reruns=3)
@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
def test_embedder_raising_overhead():
"""Tests that embedding to the output layer of a 3 layer model takes less than 10ms more than the time taken to
execute the model without the embedder.
Note that this bound is intentionally high in an effort to reduce the flakiness of the test.
+
"""
model = NLayerModel(10)
embedder = Embedder(model, "output")
diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py
index f8832547f0..44b9a9c559 100644
--- a/tests/core/utilities/test_lightning_cli.py
+++ b/tests/core/utilities/test_lightning_cli.py
@@ -13,12 +13,6 @@
import pytest
import torch
import yaml
-from packaging import version
-from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
-from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
-from pytorch_lightning.plugins.environments import SLURMEnvironment
-from torch import nn
-
from flash.core.utilities.compatibility import accelerator_connector
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.lightning_cli import (
@@ -27,6 +21,12 @@
SaveConfigCallback,
instantiate_class,
)
+from packaging import version
+from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
+from pytorch_lightning.plugins.environments import SLURMEnvironment
+from torch import nn
+
from tests.helpers.boring_model import BoringDataModule, BoringModel
torchvision_version = version.parse("0")
@@ -40,7 +40,7 @@ def test_default_args(mock_argparse, tmpdir):
"""Tests default argument parser for Trainer."""
mock_argparse.return_value = Namespace(**Trainer.default_attributes())
- parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+ parser = LightningArgumentParser(add_help=False)
args = parser.parse_args([])
args.max_epochs = 5
@@ -54,7 +54,7 @@ def test_default_args(mock_argparse, tmpdir):
@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--default_root_dir=./"], []])
def test_add_argparse_args_redefined(cli_args):
"""Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness."""
- parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+ parser = LightningArgumentParser(add_help=False)
parser.add_lightning_class_args(Trainer, None)
args = parser.parse_args(cli_args)
@@ -79,11 +79,11 @@ def test_add_argparse_args_redefined(cli_args):
("--auto_lr_find=True --auto_scale_batch_size=power", {"auto_lr_find": True, "auto_scale_batch_size": "power"}),
(
"--auto_lr_find any_string --auto_scale_batch_size ON",
- {"auto_lr_find": "any_string", "auto_scale_batch_size": True},
+ {"auto_lr_find": "any_string", "auto_scale_batch_size": "ON"},
),
- ("--auto_lr_find=Yes --auto_scale_batch_size=On", {"auto_lr_find": True, "auto_scale_batch_size": True}),
- ("--auto_lr_find Off --auto_scale_batch_size No", {"auto_lr_find": False, "auto_scale_batch_size": False}),
- ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", {"auto_lr_find": True, "auto_scale_batch_size": False}),
+ ("--auto_lr_find=Yes --auto_scale_batch_size=On", {"auto_lr_find": True, "auto_scale_batch_size": "On"}),
+ ("--auto_lr_find Off --auto_scale_batch_size No", {"auto_lr_find": False, "auto_scale_batch_size": "No"}),
+ ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", {"auto_lr_find": True, "auto_scale_batch_size": "FALSE"}),
("--limit_train_batches=100", {"limit_train_batches": 100}),
("--limit_train_batches 0.8", {"limit_train_batches": 0.8}),
],
@@ -91,7 +91,7 @@ def test_add_argparse_args_redefined(cli_args):
def test_parse_args_parsing(cli_args, expected):
"""Test parsing simple types and None optionals not modified."""
cli_args = cli_args.split(" ") if cli_args else []
- parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+ parser = LightningArgumentParser(add_help=False)
parser.add_lightning_class_args(Trainer, None)
with patch("sys.argv", ["any.py"] + cli_args):
args = parser.parse_args()
@@ -112,7 +112,7 @@ def test_parse_args_parsing(cli_args, expected):
)
def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
"""Test parsing complex types."""
- parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+ parser = LightningArgumentParser(add_help=False)
parser.add_lightning_class_args(Trainer, None)
with patch("sys.argv", ["any.py"] + cli_args):
args = parser.parse_args()
@@ -137,7 +137,7 @@ def test_parse_args_parsing_gpus(mocker, cli_args, expected_gpu):
"""Test parsing of gpus and instantiation of Trainer."""
mocker.patch("lightning_lite.utilities.device_parser._get_all_available_gpus", return_value=[0, 1])
cli_args = cli_args.split(" ") if cli_args else []
- parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+ parser = LightningArgumentParser(add_help=False)
parser.add_lightning_class_args(Trainer, None)
with patch("sys.argv", ["any.py"] + cli_args):
args = parser.parse_args()
@@ -310,8 +310,8 @@ def test_lightning_cli_args(tmpdir):
config = yaml.safe_load(f.read())
assert "model" not in config
assert "model" not in cli.config
- assert config["data"] == cli.config["data"]
- assert config["trainer"] == cli.config["trainer"]
+ assert config["data"] == cli.config["data"].as_dict()
+ assert config["trainer"] == cli.config["trainer"].as_dict()
@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
@@ -363,9 +363,9 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
assert os.path.isfile(config_path)
with open(config_path) as f:
config = yaml.safe_load(f.read())
- assert config["model"] == cli.config["model"]
- assert config["data"] == cli.config["data"]
- assert config["trainer"] == cli.config["trainer"]
+ assert config["model"] == cli.config["model"].as_dict()
+ assert config["data"] == cli.config["data"].as_dict()
+ assert config["trainer"] == cli.config["trainer"].as_dict()
def any_model_any_data_cli():
@@ -578,16 +578,19 @@ def on_exception(self, execption):
raise execption
+@pytest.mark.skipif(os.name == "nt", reason="Strange DDP values, need to debug later...") # todo
@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
@pytest.mark.parametrize("logger", [False, True])
@pytest.mark.parametrize(
"trainer_kwargs",
[
{"accelerator": "cpu", "strategy": "ddp"},
- {"accelerator": "cpu", "strategy": "ddp", "plugins": "ddp_find_unused_parameters_false"},
+ pytest.param(
+ {"accelerator": "cpu", "strategy": "ddp", "plugins": "ddp_find_unused_parameters_false"},
+ marks=pytest.mark.xfail(reason="Bugs in PL >= 1.6.0"),
+ ),
],
)
-@pytest.mark.xfail(reason="Bugs in PL >= 1.6.0")
def test_cli_ddp_spawn_save_config_callback(tmpdir, logger, trainer_kwargs):
with patch("sys.argv", ["any.py"]), pytest.raises(CustomException):
LightningCLI(
diff --git a/tests/core/utilities/test_stability.py b/tests/core/utilities/test_stability.py
index 16916be03f..011a1aa31c 100644
--- a/tests/core/utilities/test_stability.py
+++ b/tests/core/utilities/test_stability.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-
from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE
from flash.core.utilities.stability import _raise_beta_warning, beta
diff --git a/tests/deprecated_api/test_remove_0_9_0.py b/tests/deprecated_api/test_remove_0_9_0.py
index 363a0d788b..88ff568942 100644
--- a/tests/deprecated_api/test_remove_0_9_0.py
+++ b/tests/deprecated_api/test_remove_0_9_0.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-
from flash.core.utilities.imports import _VISSL_AVAILABLE
from flash.image.embedding.model import ImageEmbedder
diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py
index 745912852a..d275af2711 100644
--- a/tests/examples/test_scripts.py
+++ b/tests/examples/test_scripts.py
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+import shutil
import sys
from pathlib import Path
from unittest.mock import patch
import pytest
import torch
-
from flash.core.utilities.imports import (
_BAAL_AVAILABLE,
_FIFTYONE_AVAILABLE,
@@ -37,6 +37,7 @@
_TORCHVISION_GREATER_EQUAL_0_9,
_VISSL_AVAILABLE,
)
+
from tests.examples.helpers import run_test
from tests.helpers.decorators import forked
@@ -59,7 +60,10 @@
pytest.param(
"audio",
"audio_classification.py",
- marks=pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed"),
+ marks=[
+ pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed"),
+ pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"),
+ ],
),
pytest.param(
"audio",
@@ -235,3 +239,11 @@
@pytest.mark.skipif(sys.platform == "darwin", reason="Fatal Python error: Illegal instruction") # fixme
def test_example(folder, fname):
run_test(str(root / "examples" / folder / fname))
+
+ # clean ALL logs and used data
+ shutil.rmtree(root / "data", ignore_errors=True)
+ shutil.rmtree(root / "lightning_logs", ignore_errors=True)
+
+ # remove all saved models
+ for p in root.glob("*.pt"):
+ p.unlink()
diff --git a/tests/graph/classification/test_data.py b/tests/graph/classification/test_data.py
index 0d35fb232a..058dd63dcb 100644
--- a/tests/graph/classification/test_data.py
+++ b/tests/graph/classification/test_data.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-
from flash import DataKeys
from flash.core.utilities.imports import _TOPIC_GRAPH_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.graph.classification.data import GraphClassificationData
diff --git a/tests/graph/classification/test_model.py b/tests/graph/classification/test_model.py
index 9168ed1602..496d345d76 100644
--- a/tests/graph/classification/test_model.py
+++ b/tests/graph/classification/test_model.py
@@ -15,8 +15,6 @@
import pytest
import torch
-from torch import Tensor
-
from flash import RunningStage, Trainer
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import DataKeys
@@ -24,6 +22,8 @@
from flash.graph.classification import GraphClassifier
from flash.graph.classification.input import GraphClassificationDatasetInput
from flash.graph.classification.input_transform import GraphClassificationInputTransform
+from torch import Tensor
+
from tests.helpers.task_tester import TaskTester
if _TOPIC_GRAPH_AVAILABLE:
diff --git a/tests/graph/embedding/test_model.py b/tests/graph/embedding/test_model.py
index 78bf198a0f..1143bc5a05 100644
--- a/tests/graph/embedding/test_model.py
+++ b/tests/graph/embedding/test_model.py
@@ -15,8 +15,6 @@
import pytest
import torch
-from torch import Tensor
-
from flash import RunningStage, Trainer
from flash.core.data.data_module import DataModule
from flash.core.utilities.imports import _TOPIC_GRAPH_AVAILABLE
@@ -24,6 +22,8 @@
from flash.graph.classification.input_transform import GraphClassificationInputTransform
from flash.graph.classification.model import GraphClassifier
from flash.graph.embedding.model import GraphEmbedder
+from torch import Tensor
+
from tests.helpers.task_tester import TaskTester
if _TOPIC_GRAPH_AVAILABLE:
diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py
index 96505cde7a..14e5f51d9b 100644
--- a/tests/helpers/boring_model.py
+++ b/tests/helpers/boring_model.py
@@ -33,6 +33,7 @@ def training_step(...):
model = BaseTestModel()
model.training_epoch_end = None
+
"""
super().__init__()
self.layer = torch.nn.Linear(32, 2)
diff --git a/tests/helpers/task_tester.py b/tests/helpers/task_tester.py
index 97dbefb571..c480c1641f 100644
--- a/tests/helpers/task_tester.py
+++ b/tests/helpers/task_tester.py
@@ -20,13 +20,12 @@
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import patch
+import flash
import pytest
import torch
-from torch.utils.data import Dataset
-
-import flash
from flash.__main__ import main
from flash.core.model import Task
+from torch.utils.data import Dataset
def _copy_func(f):
diff --git a/tests/image/classification/test_active_learning.py b/tests/image/classification/test_active_learning.py
index 6ad08932b1..eec23ba446 100644
--- a/tests/image/classification/test_active_learning.py
+++ b/tests/image/classification/test_active_learning.py
@@ -14,17 +14,17 @@
import math
from pathlib import Path
+import flash
import numpy as np
import pytest
import torch
+from flash.core.utilities.imports import _BAAL_AVAILABLE, _BAAL_GREATER_EQUAL_1_5_2, _TOPIC_IMAGE_AVAILABLE
+from flash.image import ImageClassificationData, ImageClassifier
+from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop
from pytorch_lightning import seed_everything
from torch import nn
from torch.utils.data import SequentialSampler
-import flash
-from flash.core.utilities.imports import _BAAL_AVAILABLE, _BAAL_GREATER_EQUAL_1_5_2, _TOPIC_IMAGE_AVAILABLE
-from flash.image import ImageClassificationData, ImageClassifier
-from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop
from tests.image.classification.test_data import _rand_image
# ======== Mock functions ========
diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py
index 1aace9a033..fdfc4f1ae8 100644
--- a/tests/image/classification/test_data.py
+++ b/tests/image/classification/test_data.py
@@ -21,7 +21,6 @@
import pandas as pd
import pytest
import torch
-
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import (
_FIFTYONE_AVAILABLE,
diff --git a/tests/image/classification/test_data_model_integration.py b/tests/image/classification/test_data_model_integration.py
index 8d8eb9aa5f..08c8cf338a 100644
--- a/tests/image/classification/test_data_model_integration.py
+++ b/tests/image/classification/test_data_model_integration.py
@@ -15,7 +15,6 @@
import numpy as np
import pytest
-
from flash import Trainer
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PIL_AVAILABLE, _TOPIC_IMAGE_AVAILABLE
from flash.image import ImageClassificationData, ImageClassifier
diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py
index 8de7a109ac..a865063183 100644
--- a/tests/image/classification/test_model.py
+++ b/tests/image/classification/test_model.py
@@ -16,12 +16,12 @@
import pytest
import torch
-from torch import Tensor
-
from flash import Trainer
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE, _TOPIC_SERVE_AVAILABLE
from flash.image import ImageClassifier
+from torch import Tensor
+
from tests.helpers.task_tester import TaskTester
# ======== Mock functions ========
diff --git a/tests/image/classification/test_training_strategies.py b/tests/image/classification/test_training_strategies.py
index dd524dd722..c3c40a3b88 100644
--- a/tests/image/classification/test_training_strategies.py
+++ b/tests/image/classification/test_training_strategies.py
@@ -16,13 +16,13 @@
import pytest
import torch
-from torch.utils.data import DataLoader
-
from flash import Trainer
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _LEARN2LEARN_AVAILABLE, _TOPIC_IMAGE_AVAILABLE
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.adapters import TRAINING_STRATEGIES
+from torch.utils.data import DataLoader
+
from tests.image.classification.test_data import _rand_image
# ======== Mock functions ========
diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py
index 18703fe951..1dea17fc3b 100644
--- a/tests/image/detection/test_data.py
+++ b/tests/image/detection/test_data.py
@@ -16,7 +16,6 @@
from pathlib import Path
import pytest
-
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, _PIL_AVAILABLE
from flash.image.detection.data import ObjectDetectionData
diff --git a/tests/image/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py
index 4ebf18752f..7d4f642919 100644
--- a/tests/image/detection/test_data_model_integration.py
+++ b/tests/image/detection/test_data_model_integration.py
@@ -13,10 +13,9 @@
# limitations under the License.
import os
+import flash
import pytest
import torch
-
-import flash
from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _PIL_AVAILABLE
from flash.image import ObjectDetector
from flash.image.detection import ObjectDetectionData
diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py
index 9465b69df0..f01d04b89b 100644
--- a/tests/image/detection/test_model.py
+++ b/tests/image/detection/test_model.py
@@ -18,8 +18,6 @@
import numpy as np
import pytest
import torch
-from torch.utils.data import Dataset
-
from flash.core.data.io.input import DataKeys
from flash.core.integrations.icevision.transforms import IceVisionInputTransform
from flash.core.trainer import Trainer
@@ -30,6 +28,8 @@
_TOPIC_SERVE_AVAILABLE,
)
from flash.image import ObjectDetector
+from torch.utils.data import Dataset
+
from tests.helpers.task_tester import TaskTester
diff --git a/tests/image/detection/test_output.py b/tests/image/detection/test_output.py
index 7ad27f60d6..62efd82de0 100644
--- a/tests/image/detection/test_output.py
+++ b/tests/image/detection/test_output.py
@@ -1,7 +1,6 @@
import numpy as np
import pytest
import torch
-
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TOPIC_IMAGE_AVAILABLE
from flash.image.detection.output import FiftyOneDetectionLabelsOutput
diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py
index 7848fdd84f..a43f55db3e 100644
--- a/tests/image/embedding/test_model.py
+++ b/tests/image/embedding/test_model.py
@@ -13,13 +13,13 @@
# limitations under the License.
from typing import Any
+import flash
import pytest
import torch
-from torch import Tensor
-
-import flash
from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE, _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE
from flash.image import ImageClassificationData, ImageEmbedder
+from torch import Tensor
+
from tests.helpers.task_tester import TaskTester
if _TORCHVISION_AVAILABLE:
diff --git a/tests/image/face_detection/test_model.py b/tests/image/face_detection/test_model.py
index a7da39c1dc..9dd5b75d17 100644
--- a/tests/image/face_detection/test_model.py
+++ b/tests/image/face_detection/test_model.py
@@ -14,9 +14,8 @@
import contextlib
from unittest.mock import patch
-import pytest
-
import flash
+import pytest
from flash.__main__ import main
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FASTFACE_AVAILABLE
@@ -25,7 +24,6 @@
if _FASTFACE_AVAILABLE:
import fastface as ff
from fastface.arch.lffd import LFFD
-
from flash.image.face_detection.backbones import FACE_DETECTION_BACKBONES
else:
FACE_DETECTION_BACKBONES = FlashRegistry("face_detection_backbones")
diff --git a/tests/image/instance_segm/test_data.py b/tests/image/instance_segm/test_data.py
index fe00336ba3..d2b657b394 100644
--- a/tests/image/instance_segm/test_data.py
+++ b/tests/image/instance_segm/test_data.py
@@ -14,11 +14,11 @@
import numpy as np
import pytest
import torch
-
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE
from flash.image.instance_segmentation import InstanceSegmentationData
from flash.image.instance_segmentation.data import InstanceSegmentationOutputTransform
+
from tests.image.detection.test_data import _create_synth_files_dataset, _create_synth_folders_dataset
diff --git a/tests/image/instance_segm/test_model.py b/tests/image/instance_segm/test_model.py
index ae9fb1059b..1660c2a8fc 100644
--- a/tests/image/instance_segm/test_model.py
+++ b/tests/image/instance_segm/test_model.py
@@ -19,11 +19,11 @@
import numpy as np
import pytest
import torch
-
from flash import Trainer
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, _ICEVISION_AVAILABLE, _TOPIC_IMAGE_AVAILABLE
from flash.image import InstanceSegmentation, InstanceSegmentationData
+
from tests.helpers.task_tester import TaskTester
if _TOPIC_IMAGE_AVAILABLE:
diff --git a/tests/image/keypoint_detection/test_data.py b/tests/image/keypoint_detection/test_data.py
index 5de9b7b266..4cc43ea5a5 100644
--- a/tests/image/keypoint_detection/test_data.py
+++ b/tests/image/keypoint_detection/test_data.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE
from flash.image.keypoint_detection import KeypointDetectionData
+
from tests.image.detection.test_data import _create_synth_files_dataset, _create_synth_folders_dataset
diff --git a/tests/image/keypoint_detection/test_model.py b/tests/image/keypoint_detection/test_model.py
index 02769e503b..18de8c6a2e 100644
--- a/tests/image/keypoint_detection/test_model.py
+++ b/tests/image/keypoint_detection/test_model.py
@@ -19,11 +19,11 @@
import numpy as np
import pytest
import torch
-
from flash import Trainer
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, _ICEVISION_AVAILABLE, _TOPIC_IMAGE_AVAILABLE
from flash.image import KeypointDetectionData, KeypointDetector
+
from tests.helpers.task_tester import TaskTester
if _TOPIC_IMAGE_AVAILABLE:
diff --git a/tests/image/semantic_segm/test_backbones.py b/tests/image/semantic_segm/test_backbones.py
index ba05e83c7a..59d0dfcfe8 100644
--- a/tests/image/semantic_segm/test_backbones.py
+++ b/tests/image/semantic_segm/test_backbones.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-
from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
diff --git a/tests/image/semantic_segm/test_data.py b/tests/image/semantic_segm/test_data.py
index aae5129c90..b092bb422d 100644
--- a/tests/image/semantic_segm/test_data.py
+++ b/tests/image/semantic_segm/test_data.py
@@ -5,7 +5,6 @@
import numpy as np
import pytest
import torch
-
from flash import Trainer
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
diff --git a/tests/image/semantic_segm/test_heads.py b/tests/image/semantic_segm/test_heads.py
index 19472860c7..1ee328cb2f 100644
--- a/tests/image/semantic_segm/test_heads.py
+++ b/tests/image/semantic_segm/test_heads.py
@@ -15,7 +15,6 @@
import pytest
import torch
-
from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE
from flash.image.segmentation import SemanticSegmentation
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
diff --git a/tests/image/semantic_segm/test_model.py b/tests/image/semantic_segm/test_model.py
index c58b0a1632..31209e37ff 100644
--- a/tests/image/semantic_segm/test_model.py
+++ b/tests/image/semantic_segm/test_model.py
@@ -17,13 +17,13 @@
import numpy as np
import pytest
import torch
-from torch import Tensor
-
from flash import Trainer
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE, _TOPIC_IMAGE_AVAILABLE, _TOPIC_SERVE_AVAILABLE
from flash.image import SemanticSegmentation
from flash.image.segmentation.data import SemanticSegmentationData
+from torch import Tensor
+
from tests.helpers.task_tester import TaskTester
diff --git a/tests/image/semantic_segm/test_output.py b/tests/image/semantic_segm/test_output.py
index df410ccdff..c2593dc886 100644
--- a/tests/image/semantic_segm/test_output.py
+++ b/tests/image/semantic_segm/test_output.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
import torch
-
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import (
_FIFTYONE_AVAILABLE,
diff --git a/tests/image/style_transfer/test_model.py b/tests/image/style_transfer/test_model.py
index 51fea84407..ec23e46fdd 100644
--- a/tests/image/style_transfer/test_model.py
+++ b/tests/image/style_transfer/test_model.py
@@ -16,15 +16,15 @@
import pytest
import torch
-from torch import Tensor
-
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE
from flash.image.style_transfer import StyleTransfer
+from torch import Tensor
+
from tests.helpers.task_tester import TaskTester
-@pytest.mark.xfail(URLError, reason="Connection timed out for download.pystiche.org")
+@pytest.mark.xfail(URLError, reason="Connection timed out for download.pystiche.org", strict=False)
class TestStyleTransfer(TaskTester):
task = StyleTransfer
cli_command = "style_transfer"
diff --git a/tests/image/test_backbones.py b/tests/image/test_backbones.py
index 5d82b004d0..7dbc510a74 100644
--- a/tests/image/test_backbones.py
+++ b/tests/image/test_backbones.py
@@ -14,7 +14,6 @@
import urllib.error
import pytest
-
from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE
from flash.core.utilities.url_error import catch_url_error
from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES
diff --git a/tests/pointcloud/detection/test_data.py b/tests/pointcloud/detection/test_data.py
index 4515d17a50..58abdc7d39 100644
--- a/tests/pointcloud/detection/test_data.py
+++ b/tests/pointcloud/detection/test_data.py
@@ -15,13 +15,12 @@
import pytest
import torch
-from pytorch_lightning import seed_everything
-
from flash import Trainer
from flash.core.data.io.input import DataKeys
from flash.core.data.utils import download_data
from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE
from flash.pointcloud.detection import PointCloudObjectDetector, PointCloudObjectDetectorData
+from pytorch_lightning import seed_everything
if _TOPIC_POINTCLOUD_AVAILABLE:
from flash.pointcloud.detection.open3d_ml.backbones import ObjectDetectBatchCollator
diff --git a/tests/pointcloud/detection/test_model.py b/tests/pointcloud/detection/test_model.py
index aa088928e5..fbc135cf0b 100644
--- a/tests/pointcloud/detection/test_model.py
+++ b/tests/pointcloud/detection/test_model.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-
from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE
from flash.pointcloud.detection import PointCloudObjectDetector
+
from tests.helpers.task_tester import TaskTester
diff --git a/tests/pointcloud/segmentation/test_data.py b/tests/pointcloud/segmentation/test_data.py
index 3e2314ecaa..e007449816 100644
--- a/tests/pointcloud/segmentation/test_data.py
+++ b/tests/pointcloud/segmentation/test_data.py
@@ -15,13 +15,12 @@
import pytest
import torch
-from pytorch_lightning import seed_everything
-
from flash import Trainer
from flash.core.data.io.input import DataKeys
from flash.core.data.utils import download_data
from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE
from flash.pointcloud.segmentation import PointCloudSegmentation, PointCloudSegmentationData
+from pytorch_lightning import seed_everything
@pytest.mark.skipif(not _TOPIC_POINTCLOUD_AVAILABLE, reason="pointcloud libraries aren't installed")
diff --git a/tests/pointcloud/segmentation/test_datasets.py b/tests/pointcloud/segmentation/test_datasets.py
index 6c09f9da95..b0a168cd02 100644
--- a/tests/pointcloud/segmentation/test_datasets.py
+++ b/tests/pointcloud/segmentation/test_datasets.py
@@ -14,7 +14,6 @@
from unittest.mock import patch
import pytest
-
from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE
from flash.pointcloud.segmentation.datasets import LyftDataset, SemanticKITTIDataset
diff --git a/tests/pointcloud/segmentation/test_model.py b/tests/pointcloud/segmentation/test_model.py
index 60456475f3..322914172e 100644
--- a/tests/pointcloud/segmentation/test_model.py
+++ b/tests/pointcloud/segmentation/test_model.py
@@ -13,9 +13,9 @@
# limitations under the License.
import pytest
import torch
-
from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE
from flash.pointcloud.segmentation import PointCloudSegmentation
+
from tests.helpers.task_tester import TaskTester
diff --git a/tests/tabular/classification/test_data.py b/tests/tabular/classification/test_data.py
index 6776ad35b7..8764a706ef 100644
--- a/tests/tabular/classification/test_data.py
+++ b/tests/tabular/classification/test_data.py
@@ -16,13 +16,11 @@
import numpy as np
import pytest
-
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _TOPIC_TABULAR_AVAILABLE
if _TOPIC_TABULAR_AVAILABLE:
import pandas as pd
-
from flash.tabular import TabularClassificationData
from flash.tabular.classification.utils import _categorize, _compute_normalization, _generate_codes, _normalize
diff --git a/tests/tabular/classification/test_data_model_integration.py b/tests/tabular/classification/test_data_model_integration.py
index f52f045830..166ecb78be 100644
--- a/tests/tabular/classification/test_data_model_integration.py
+++ b/tests/tabular/classification/test_data_model_integration.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
import pytorch_lightning as pl
-
from flash.core.utilities.imports import _TOPIC_TABULAR_AVAILABLE
from flash.tabular import TabularClassificationData, TabularClassifier
@@ -39,8 +38,7 @@
("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
- # ("category_embedding", # todo: seems to be bug in tabular
- # {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
+ ("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# No categorical / numerical fields
("tabnet", {"categorical_fields": ["category"]}),
("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}),
diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py
index 008a797a99..0ca28f15c5 100644
--- a/tests/tabular/classification/test_model.py
+++ b/tests/tabular/classification/test_model.py
@@ -14,16 +14,16 @@
from typing import Any
from unittest.mock import patch
+import flash
import pandas as pd
import pytest
import torch
-from torch import Tensor
-
-import flash
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE, _TOPIC_TABULAR_AVAILABLE
from flash.tabular.classification.data import TabularClassificationData
from flash.tabular.classification.model import TabularClassifier
+from torch import Tensor
+
from tests.helpers.task_tester import StaticDataset, TaskTester
@@ -55,7 +55,7 @@ class TestTabularClassifier(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
- # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
+ {"backbone": "category_embedding"},
],
)
],
@@ -68,7 +68,7 @@ class TestTabularClassifier(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
- # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
+ {"backbone": "category_embedding"},
],
)
],
@@ -81,7 +81,7 @@ class TestTabularClassifier(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
- # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
+ {"backbone": "category_embedding"},
],
)
],
diff --git a/tests/tabular/forecasting/test_data.py b/tests/tabular/forecasting/test_data.py
index ad640da9b8..db01c6e87b 100644
--- a/tests/tabular/forecasting/test_data.py
+++ b/tests/tabular/forecasting/test_data.py
@@ -14,7 +14,6 @@
from unittest.mock import MagicMock, patch
import pytest
-
from flash.core.utilities.imports import _TOPIC_TABULAR_AVAILABLE
from flash.tabular.forecasting import TabularForecastingData
diff --git a/tests/tabular/forecasting/test_model.py b/tests/tabular/forecasting/test_model.py
index ca0fe2e41a..c56abbd014 100644
--- a/tests/tabular/forecasting/test_model.py
+++ b/tests/tabular/forecasting/test_model.py
@@ -13,14 +13,14 @@
# limitations under the License.
from typing import Any
+import flash
import pytest
import torch
-from torch import Tensor
-
-import flash
from flash import DataKeys
from flash.core.utilities.imports import _TOPIC_TABULAR_AVAILABLE
from flash.tabular.forecasting import TabularForecaster
+from torch import Tensor
+
from tests.helpers.task_tester import StaticDataset, TaskTester
if _TOPIC_TABULAR_AVAILABLE:
diff --git a/tests/tabular/regression/test_data_model_integration.py b/tests/tabular/regression/test_data_model_integration.py
index 0a01bac532..220d5f77aa 100644
--- a/tests/tabular/regression/test_data_model_integration.py
+++ b/tests/tabular/regression/test_data_model_integration.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
import pytorch_lightning as pl
-
from flash.core.utilities.imports import _TOPIC_TABULAR_AVAILABLE
from flash.tabular import TabularRegressionData, TabularRegressor
@@ -48,8 +47,7 @@
("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
- # ("category_embedding", # todo: seems to be bug in tabular
- # {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
+ ("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# No categorical / numerical fields
("tabnet", {"categorical_fields": ["category"]}),
("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}),
@@ -82,8 +80,7 @@ def test_regression_data_frame(backbone, fields, tmpdir):
("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
- # ("category_embedding", # todo: seems to be bug in tabular
- # {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
+ ("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# No categorical / numerical fields
("tabnet", {"categorical_fields": ["category"]}),
("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}),
@@ -113,8 +110,7 @@ def test_regression_dicts(backbone, fields, tmpdir):
("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
- # ("category_embedding", # todo: seems to be bug in tabular
- # {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
+ ("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# No categorical / numerical fields
("tabnet", {"categorical_fields": ["category"]}),
("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}),
diff --git a/tests/tabular/regression/test_model.py b/tests/tabular/regression/test_model.py
index 52dd36ed9b..c10e579a86 100644
--- a/tests/tabular/regression/test_model.py
+++ b/tests/tabular/regression/test_model.py
@@ -14,15 +14,15 @@
from typing import Any
from unittest.mock import patch
+import flash
import pandas as pd
import pytest
import torch
-from torch import Tensor
-
-import flash
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE, _TOPIC_TABULAR_AVAILABLE
from flash.tabular import TabularRegressionData, TabularRegressor
+from torch import Tensor
+
from tests.helpers.task_tester import StaticDataset, TaskTester
@@ -53,7 +53,7 @@ class TestTabularRegressor(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
- # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
+ {"backbone": "category_embedding"},
],
)
],
@@ -66,7 +66,7 @@ class TestTabularRegressor(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
- # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
+ {"backbone": "category_embedding"},
],
)
],
@@ -79,7 +79,7 @@ class TestTabularRegressor(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
- # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
+ {"backbone": "category_embedding"},
],
)
],
diff --git a/tests/template/classification/test_data.py b/tests/template/classification/test_data.py
index fc2963e8d6..aa2a406576 100644
--- a/tests/template/classification/test_data.py
+++ b/tests/template/classification/test_data.py
@@ -13,7 +13,6 @@
# limitations under the License.
import numpy as np
import pytest
-
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _SKLEARN_AVAILABLE, _TOPIC_CORE_AVAILABLE
from flash.template.classification.data import TemplateData
diff --git a/tests/template/classification/test_model.py b/tests/template/classification/test_model.py
index 5b5f54d649..408bde4621 100644
--- a/tests/template/classification/test_model.py
+++ b/tests/template/classification/test_model.py
@@ -16,13 +16,12 @@
import numpy as np
import pytest
import torch
-from torch import Tensor
-
from flash import Trainer
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _SKLEARN_AVAILABLE, _TOPIC_CORE_AVAILABLE
from flash.template import TemplateSKLearnClassifier
from flash.template.classification.data import TemplateData
+from torch import Tensor
if _SKLEARN_AVAILABLE:
from sklearn import datasets
diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py
index e19dcd2940..dcdf095c3e 100644
--- a/tests/text/classification/test_data.py
+++ b/tests/text/classification/test_data.py
@@ -16,7 +16,6 @@
import pandas as pd
import pytest
-
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE
from flash.text import TextClassificationData
diff --git a/tests/text/classification/test_data_model_integration.py b/tests/text/classification/test_data_model_integration.py
index b4ee5d16c8..b5b06cfdfb 100644
--- a/tests/text/classification/test_data_model_integration.py
+++ b/tests/text/classification/test_data_model_integration.py
@@ -15,7 +15,6 @@
from pathlib import Path
import pytest
-
from flash.core.trainer import Trainer
from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE
from flash.text import TextClassificationData, TextClassifier
diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py
index 63830a97df..194a44bbe3 100644
--- a/tests/text/classification/test_model.py
+++ b/tests/text/classification/test_model.py
@@ -14,15 +14,15 @@
from typing import Any
from unittest.mock import patch
+import flash
import pytest
import torch
-from torch import Tensor
-
-import flash
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE, _TOPIC_TEXT_AVAILABLE, _TORCH_ORT_AVAILABLE
from flash.text import TextClassifier
from flash.text.ort_callback import ORTCallback
+from torch import Tensor
+
from tests.helpers.boring_model import BoringModel
from tests.helpers.task_tester import StaticDataset, TaskTester
diff --git a/tests/text/embedding/test_model.py b/tests/text/embedding/test_model.py
index 87a0ce513f..33e4e8e612 100644
--- a/tests/text/embedding/test_model.py
+++ b/tests/text/embedding/test_model.py
@@ -14,13 +14,13 @@
import os
from typing import Any
+import flash
import pytest
import torch
-from torch import Tensor
-
-import flash
from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE
from flash.text import TextClassificationData, TextEmbedder
+from torch import Tensor
+
from tests.helpers.task_tester import TaskTester
# ======== Mock data ========
diff --git a/tests/text/question_answering/test_data.py b/tests/text/question_answering/test_data.py
index 90ed1b3999..1beddbca2f 100644
--- a/tests/text/question_answering/test_data.py
+++ b/tests/text/question_answering/test_data.py
@@ -17,7 +17,6 @@
import pandas as pd
import pytest
-
from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE
from flash.text import QuestionAnsweringData
diff --git a/tests/text/question_answering/test_model.py b/tests/text/question_answering/test_model.py
index 144fa1163c..98405d20c1 100644
--- a/tests/text/question_answering/test_model.py
+++ b/tests/text/question_answering/test_model.py
@@ -16,10 +16,10 @@
import pytest
import torch
-from torch import Tensor
-
from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE
from flash.text import QuestionAnsweringTask
+from torch import Tensor
+
from tests.helpers.task_tester import TaskTester
TEST_BACKBONE = "distilbert-base-uncased"
diff --git a/tests/text/seq2seq/summarization/test_data.py b/tests/text/seq2seq/summarization/test_data.py
index 9241acd395..6e2ea88a4f 100644
--- a/tests/text/seq2seq/summarization/test_data.py
+++ b/tests/text/seq2seq/summarization/test_data.py
@@ -15,7 +15,6 @@
from pathlib import Path
import pytest
-
from flash import DataKeys
from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE
from flash.text import SummarizationData
diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py
index 1513cbe5b6..e3bab32fad 100644
--- a/tests/text/seq2seq/summarization/test_model.py
+++ b/tests/text/seq2seq/summarization/test_model.py
@@ -16,11 +16,11 @@
import pytest
import torch
-from torch import Tensor
-
from flash import DataKeys
from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE, _TOPIC_TEXT_AVAILABLE
from flash.text import SummarizationTask
+from torch import Tensor
+
from tests.helpers.task_tester import TaskTester
TEST_BACKBONE = "sshleifer/tiny-mbart" # tiny model for testing
diff --git a/tests/text/seq2seq/translation/test_data.py b/tests/text/seq2seq/translation/test_data.py
index 45663c9198..5d0f2f025d 100644
--- a/tests/text/seq2seq/translation/test_data.py
+++ b/tests/text/seq2seq/translation/test_data.py
@@ -15,7 +15,6 @@
from pathlib import Path
import pytest
-
from flash import DataKeys
from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE
from flash.text import TranslationData
diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py
index e6608de7ec..da9d61b9c0 100644
--- a/tests/text/seq2seq/translation/test_model.py
+++ b/tests/text/seq2seq/translation/test_model.py
@@ -16,11 +16,11 @@
import pytest
import torch
-from torch import Tensor
-
from flash import DataKeys
from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE, _TOPIC_TEXT_AVAILABLE
from flash.text import TranslationTask
+from torch import Tensor
+
from tests.helpers.task_tester import TaskTester
TEST_BACKBONE = "sshleifer/tiny-mbart" # tiny model for testing
diff --git a/tests/video/classification/test_data.py b/tests/video/classification/test_data.py
index d174e2bd5f..b3c3f289ae 100644
--- a/tests/video/classification/test_data.py
+++ b/tests/video/classification/test_data.py
@@ -15,7 +15,6 @@
import pytest
import torch
-
from flash.core.utilities.imports import _TOPIC_VIDEO_AVAILABLE
from flash.video.classification.data import VideoClassificationData
diff --git a/tests/video/classification/test_model.py b/tests/video/classification/test_model.py
index d30f7d7d6e..fc60b075ce 100644
--- a/tests/video/classification/test_model.py
+++ b/tests/video/classification/test_model.py
@@ -17,16 +17,16 @@
from pathlib import Path
from typing import Any
+import flash
import pytest
import torch
+from flash.core.data.io.input import DataKeys
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TOPIC_VIDEO_AVAILABLE
+from flash.video import VideoClassificationData, VideoClassifier
from pandas import DataFrame
from torch import Tensor
from torch.utils.data import SequentialSampler
-import flash
-from flash.core.data.io.input import DataKeys
-from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TOPIC_VIDEO_AVAILABLE
-from flash.video import VideoClassificationData, VideoClassifier
from tests.helpers.task_tester import TaskTester
from tests.video.classification.test_data import create_dummy_video_frames, temp_encoded_tensors
@@ -75,6 +75,7 @@ def temp_encoded_video(num_frames: int, fps: int, height=10, width=10, prefix=No
"""Creates a temporary lossless, mp4 video with synthetic content.
Uses a context which deletes the video after exit.
+
"""
# Lossless options.
video_codec = "libx264rgb"
@@ -93,6 +94,7 @@ def mock_video_data_frame():
Returns a labeled video file which points to this mock encoded video dataset, the ordered label and videos tuples
and the video duration in seconds.
+
"""
num_frames = 10
fps = 5
@@ -127,6 +129,7 @@ def mock_encoded_video_dataset_folder(tmpdir):
"""Creates a temporary mock encoded video directory tree with 2 videos labeled 1, 2.
Returns a directory that to this mock encoded video dataset and the video duration in seconds.
+
"""
num_frames = 10
fps = 5