From 076823a46114605e7c02982dd5d5cdc2c44d94de Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 20 Feb 2023 14:18:25 +0100 Subject: [PATCH 01/39] Bump Lightning-AI/utilities from 0.4.0 to 0.6.0 (#1521) Bumps [Lightning-AI/utilities](https://github.com/Lightning-AI/utilities) from 0.4.0 to 0.6.0. - [Release notes](https://github.com/Lightning-AI/utilities/releases) - [Changelog](https://github.com/Lightning-AI/utilities/blob/main/CHANGELOG.md) - [Commits](https://github.com/Lightning-AI/utilities/compare/v0.4.0...v0.6.0) --- updated-dependencies: - dependency-name: Lightning-AI/utilities dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci-schema.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-schema.yml b/.github/workflows/ci-schema.yml index 69ef7dcbdc..ba108cae3e 100644 --- a/.github/workflows/ci-schema.yml +++ b/.github/workflows/ci-schema.yml @@ -8,6 +8,6 @@ on: jobs: check-schema: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.4.0 + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.6.0 with: azure-dir: '' From 1b9a1c3228bdfdf2798642d3010757ea7485a039 Mon Sep 17 00:00:00 2001 From: Phil Pirozhkov Date: Mon, 20 Feb 2023 16:25:22 +0300 Subject: [PATCH 02/39] Fix BAAL scheme image URL (#1520) --- docs/source/integrations/baal.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/integrations/baal.rst b/docs/source/integrations/baal.rst index 6aa8172a77..163af500d7 100644 --- a/docs/source/integrations/baal.rst +++ b/docs/source/integrations/baal.rst @@ -22,7 +22,7 @@ The most uncertain samples will be labelled by the human to accelerate the model .. raw:: html
- +

Credit to ElementAI / Baal Team for creating this diagram flow


From a7d8d49465e408750a92aca751d4264e4b73cea2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 27 Feb 2023 15:21:06 +0100 Subject: [PATCH 03/39] Bump Lightning-AI/utilities from 0.6.0 to 0.7.1 (#1523) Bumps [Lightning-AI/utilities](https://github.com/Lightning-AI/utilities) from 0.6.0 to 0.7.1. - [Release notes](https://github.com/Lightning-AI/utilities/releases) - [Changelog](https://github.com/Lightning-AI/utilities/blob/v0.7.1/CHANGELOG.md) - [Commits](https://github.com/Lightning-AI/utilities/compare/v0.6.0...v0.7.1) --- updated-dependencies: - dependency-name: Lightning-AI/utilities dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci-schema.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-schema.yml b/.github/workflows/ci-schema.yml index ba108cae3e..6ed51d72a4 100644 --- a/.github/workflows/ci-schema.yml +++ b/.github/workflows/ci-schema.yml @@ -8,6 +8,6 @@ on: jobs: check-schema: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.6.0 + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.7.1 with: azure-dir: '' From c9cce277c8f9d4c460788786dcca631f8f99d5b3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Mar 2023 02:21:03 +0100 Subject: [PATCH 04/39] [pre-commit.ci] pre-commit suggestions (#1506) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [pre-commit.ci] pre-commit suggestions updates: - [github.com/PyCQA/docformatter: v1.5.0 → v1.5.1](https://github.com/PyCQA/docformatter/compare/v1.5.0...v1.5.1) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- docs/extensions/stability.py | 1 - flash/audio/classification/input_transform.py | 1 - flash/core/classification.py | 2 -- flash/core/data/io/input.py | 1 - flash/core/data/io/input_transform.py | 1 - flash/core/data/utilities/sort.py | 3 +-- flash/core/integrations/icevision/transforms.py | 1 - flash/core/integrations/transformers/collate.py | 1 - flash/core/serve/core.py | 1 - flash/core/serve/decorators.py | 1 - flash/core/utilities/flash_cli.py | 1 - flash/core/utilities/providers.py | 1 - flash/image/classification/input_transform.py | 1 - flash/image/classification/integrations/baal/loop.py | 1 - flash/image/instance_segmentation/data.py | 1 - flash/image/style_transfer/input_transform.py | 1 - flash/pointcloud/detection/data.py | 1 - flash/pointcloud/detection/open3d_ml/app.py | 1 - flash/tabular/data.py | 1 - flash/text/classification/adapters.py | 2 -- tests/audio/speech_recognition/test_model.py | 1 - tests/image/classification/test_model.py | 1 - tests/image/embedding/test_model.py | 1 - tests/image/keypoint_detection/test_model.py | 1 - tests/image/segmentation/test_model.py | 1 - tests/image/style_transfer/test_model.py | 1 - tests/tabular/forecasting/test_model.py | 1 - tests/text/embedding/test_model.py | 1 - tests/video/classification/test_model.py | 1 - 30 files changed, 2 insertions(+), 33 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a3db49f5e0..246143ca74 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -54,7 +54,7 @@ repos: - id: nbstripout - repo: https://github.com/PyCQA/docformatter - rev: v1.5.0 + rev: v1.5.1 hooks: - id: docformatter args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120] diff --git a/docs/extensions/stability.py b/docs/extensions/stability.py index 01f1168004..e979f63849 100644 --- a/docs/extensions/stability.py +++ b/docs/extensions/stability.py @@ -39,7 +39,6 @@ class Beta(Directive): final_argument_whitespace = True def run(self): - message = self.arguments[-1].strip() admonition_rst = ADMONITION_TEMPLATE.format(type="beta", title="Beta", message=message) diff --git a/flash/audio/classification/input_transform.py b/flash/audio/classification/input_transform.py index 22384d21e8..1cb876e5c4 100644 --- a/flash/audio/classification/input_transform.py +++ b/flash/audio/classification/input_transform.py @@ -30,7 +30,6 @@ @dataclass class AudioClassificationInputTransform(InputTransform): - spectrogram_size: Tuple[int, int] = (128, 128) time_mask_param: Optional[int] = None freq_mask_param: Optional[int] = None diff --git a/flash/core/classification.py b/flash/core/classification.py index 70962224f6..5d94829945 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -78,7 +78,6 @@ def to_metrics_format(self, x: Tensor) -> Tensor: class ClassificationTask(ClassificationMixin, Task): - outputs: FlashRegistry = Task.outputs + CLASSIFICATION_OUTPUTS def __init__( @@ -103,7 +102,6 @@ def __init__( class ClassificationAdapterTask(ClassificationMixin, AdapterTask): - outputs: FlashRegistry = Task.outputs + CLASSIFICATION_OUTPUTS def __init__( diff --git a/flash/core/data/io/input.py b/flash/core/data/io/input.py index c34ec82fb3..8c1e669bea 100644 --- a/flash/core/data/io/input.py +++ b/flash/core/data/io/input.py @@ -168,7 +168,6 @@ class InputBase(Properties, metaclass=_InputMeta): """ def __init__(self, running_stage: RunningStage, *args: Any, **kwargs: Any) -> None: - super().__init__(running_stage=running_stage) self.data = None diff --git a/flash/core/data/io/input_transform.py b/flash/core/data/io/input_transform.py index 86ed0f05bb..a8694bc5b7 100644 --- a/flash/core/data/io/input_transform.py +++ b/flash/core/data/io/input_transform.py @@ -27,7 +27,6 @@ class InputTransformPlacement(LightningEnum): - PER_SAMPLE_TRANSFORM = "per_sample_transform" PER_BATCH_TRANSFORM = "per_batch_transform" COLLATE = "collate" diff --git a/flash/core/data/utilities/sort.py b/flash/core/data/utilities/sort.py index c5a05d42dd..36817c6531 100644 --- a/flash/core/data/utilities/sort.py +++ b/flash/core/data/utilities/sort.py @@ -27,7 +27,6 @@ def sorted_alphanumeric(iterable: Iterable[str]) -> Iterable[str]: """Sort the given iterable in the way that humans expect. For example, given ``{"class_1", "class_11", "class_2"}`` this returns ``["class_1", "class_2", "class_11"]``. - Copied from: - https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/ + Copied from: https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/ """ return sorted(iterable, key=_alphanumeric_key) diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py index d85a46ddd1..7bc4c083e5 100644 --- a/flash/core/integrations/icevision/transforms.py +++ b/flash/core/integrations/icevision/transforms.py @@ -283,7 +283,6 @@ def forward(self, x): @dataclass class IceVisionInputTransform(InputTransform): - image_size: int = 128 @requires("image", "icevision") diff --git a/flash/core/integrations/transformers/collate.py b/flash/core/integrations/transformers/collate.py index fc7b7a6682..b78f2aa87a 100644 --- a/flash/core/integrations/transformers/collate.py +++ b/flash/core/integrations/transformers/collate.py @@ -25,7 +25,6 @@ @dataclass(unsafe_hash=True) class TransformersCollate: - backbone: str tokenizer_kwargs: Optional[Dict[str, Any]] = field(default_factory=dict, hash=False) diff --git a/flash/core/serve/core.py b/flash/core/serve/core.py index 78caac9fbe..986cb63ba0 100644 --- a/flash/core/serve/core.py +++ b/flash/core/serve/core.py @@ -63,7 +63,6 @@ def __post_init__(self): class FlashServeScriptLoader: - __slots__ = ("location", "instance") def __init__(self, location: FilePath): diff --git a/flash/core/serve/decorators.py b/flash/core/serve/decorators.py index 18b884c5a6..ab9de9f682 100644 --- a/flash/core/serve/decorators.py +++ b/flash/core/serve/decorators.py @@ -32,7 +32,6 @@ class UnboundMeta: @dataclass(unsafe_hash=True) class BoundMeta(UnboundMeta): - models: Union[List["Servable"], Tuple["Servable", ...], Dict[str, "Servable"]] uid: str = field(default_factory=lambda: uuid4().hex, init=False) out_attr_dict: ParameterContainer = field(default=None, init=False) diff --git a/flash/core/utilities/flash_cli.py b/flash/core/utilities/flash_cli.py index 27e3bc1756..bf2ef4abf4 100644 --- a/flash/core/utilities/flash_cli.py +++ b/flash/core/utilities/flash_cli.py @@ -115,7 +115,6 @@ def get_overlapping_args(func_a, func_b) -> Set[str]: @beta("Flash Zero is currently in Beta.") class FlashCLI(LightningCLI): - datamodule: DataModule config_init: Namespace model: LightningModule diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py index b6eed0e8b4..00c019ab74 100644 --- a/flash/core/utilities/providers.py +++ b/flash/core/utilities/providers.py @@ -18,7 +18,6 @@ @dataclass class Provider: - name: str url: str diff --git a/flash/image/classification/input_transform.py b/flash/image/classification/input_transform.py index 050944b64b..33953853a8 100644 --- a/flash/image/classification/input_transform.py +++ b/flash/image/classification/input_transform.py @@ -43,7 +43,6 @@ def forward(self, x): @dataclass class ImageClassificationInputTransform(InputTransform): - image_size: Tuple[int, int] = (196, 196) mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406) std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225) diff --git a/flash/image/classification/integrations/baal/loop.py b/flash/image/classification/integrations/baal/loop.py index 13a5cc7171..bb10a6c9cd 100644 --- a/flash/image/classification/integrations/baal/loop.py +++ b/flash/image/classification/integrations/baal/loop.py @@ -114,7 +114,6 @@ def on_advance_start(self, *args: Any, **kwargs: Any) -> None: self.progress.increment_ready() def advance(self, *args: Any, **kwargs: Any) -> None: - self.progress.increment_started() if self.trainer.datamodule.has_labelled_data: diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index 0c7c80ae4a..d6533ab074 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -63,7 +63,6 @@ def per_sample_transform(self, sample: Any) -> Any: class InstanceSegmentationData(DataModule): - input_transform_cls = IceVisionInputTransform output_transform_cls = InstanceSegmentationOutputTransform diff --git a/flash/image/style_transfer/input_transform.py b/flash/image/style_transfer/input_transform.py index 295870370c..414a3f4b9b 100644 --- a/flash/image/style_transfer/input_transform.py +++ b/flash/image/style_transfer/input_transform.py @@ -25,7 +25,6 @@ @dataclass class StyleTransferInputTransform(InputTransform): - image_size: int = 256 def per_sample_transform(self) -> Callable: diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py index 6a7306b691..cb550343fa 100644 --- a/flash/pointcloud/detection/data.py +++ b/flash/pointcloud/detection/data.py @@ -30,7 +30,6 @@ @beta("Point cloud object detection is currently in Beta.") class PointCloudObjectDetectorData(DataModule): - input_transform_cls = InputTransform @classmethod diff --git a/flash/pointcloud/detection/open3d_ml/app.py b/flash/pointcloud/detection/open3d_ml/app.py index 9968a707ef..8022a94e71 100644 --- a/flash/pointcloud/detection/open3d_ml/app.py +++ b/flash/pointcloud/detection/open3d_ml/app.py @@ -103,7 +103,6 @@ def on_done_ui(): gui.Application.instance.run() class VizDataset(Dataset): - name = "VizDataset" def __init__(self, dataset): diff --git a/flash/tabular/data.py b/flash/tabular/data.py index dbebb69594..42739d5440 100644 --- a/flash/tabular/data.py +++ b/flash/tabular/data.py @@ -19,7 +19,6 @@ class TabularData(DataModule): - input_transform_cls = InputTransform output_transform_cls = OutputTransform diff --git a/flash/text/classification/adapters.py b/flash/text/classification/adapters.py index d2f57423da..517dd85da9 100644 --- a/flash/text/classification/adapters.py +++ b/flash/text/classification/adapters.py @@ -88,7 +88,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A @dataclass class GenericCollate: - tokenizer: Callable[[str], Any] @staticmethod @@ -110,7 +109,6 @@ def __call__(self, samples): class GenericAdapter(Adapter): - heads: FlashRegistry = CLASSIFIER_HEADS def __init__(self, backbone, num_classes: int, max_length: int = 128, head="linear"): diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py index 84dcbda651..19e53b70f7 100644 --- a/tests/audio/speech_recognition/test_model.py +++ b/tests/audio/speech_recognition/test_model.py @@ -28,7 +28,6 @@ class TestSpeechRecognition(TaskTester): - task = SpeechRecognition task_kwargs = dict(backbone=TEST_BACKBONE) cli_command = "speech_recognition" diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py index 9879d0c8cd..b3a146c925 100644 --- a/tests/image/classification/test_model.py +++ b/tests/image/classification/test_model.py @@ -45,7 +45,6 @@ def __len__(self) -> int: class TestImageClassifier(TaskTester): - task = ImageClassifier task_args = (2,) cli_command = "image_classification" diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py index bc8f16f924..9f2ec8e476 100644 --- a/tests/image/embedding/test_model.py +++ b/tests/image/embedding/test_model.py @@ -29,7 +29,6 @@ class TestImageEmbedder(TaskTester): - task = ImageEmbedder task_kwargs = dict( backbone="resnet18", diff --git a/tests/image/keypoint_detection/test_model.py b/tests/image/keypoint_detection/test_model.py index b0980961cd..328d235890 100644 --- a/tests/image/keypoint_detection/test_model.py +++ b/tests/image/keypoint_detection/test_model.py @@ -94,7 +94,6 @@ def coco_keypoints(tmpdir): class TestKeypointDetector(TaskTester): - task = KeypointDetector task_args = (2,) task_kwargs = {"num_classes": 2} diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py index ff619fb0ae..1d6e7d353c 100644 --- a/tests/image/segmentation/test_model.py +++ b/tests/image/segmentation/test_model.py @@ -28,7 +28,6 @@ class TestSemanticSegmentation(TaskTester): - task = SemanticSegmentation task_args = (2,) cli_command = "semantic_segmentation" diff --git a/tests/image/style_transfer/test_model.py b/tests/image/style_transfer/test_model.py index 3ad7fd71eb..7ce7a8dd10 100644 --- a/tests/image/style_transfer/test_model.py +++ b/tests/image/style_transfer/test_model.py @@ -24,7 +24,6 @@ class TestStyleTransfer(TaskTester): - task = StyleTransfer cli_command = "style_transfer" is_testing = _IMAGE_TESTING diff --git a/tests/tabular/forecasting/test_model.py b/tests/tabular/forecasting/test_model.py index 9e0e8adabb..a786b8196b 100644 --- a/tests/tabular/forecasting/test_model.py +++ b/tests/tabular/forecasting/test_model.py @@ -31,7 +31,6 @@ class TestTabularForecaster(TaskTester): - task = TabularForecaster # TODO: Reduce number of required parameters task_kwargs = { diff --git a/tests/text/embedding/test_model.py b/tests/text/embedding/test_model.py index 75886f56f8..f295f97e17 100644 --- a/tests/text/embedding/test_model.py +++ b/tests/text/embedding/test_model.py @@ -37,7 +37,6 @@ class TestTextEmbedder(TaskTester): - task = TextEmbedder task_kwargs = {"backbone": TEST_BACKBONE} is_testing = _TEXT_TESTING diff --git a/tests/video/classification/test_model.py b/tests/video/classification/test_model.py index 1348eb8264..d0f9e7615c 100644 --- a/tests/video/classification/test_model.py +++ b/tests/video/classification/test_model.py @@ -39,7 +39,6 @@ class TestVideoClassifier(TaskTester): - task = VideoClassifier task_args = (2,) task_kwargs = {"pretrained": False, "backbone": "slow_r50"} From 8cd41f5a171b018af8776f99ae80f71b2443fa97 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Sat, 4 Mar 2023 02:36:01 +0100 Subject: [PATCH 05/39] update precommit & formatting (#1525) * update precommit * update * fixing --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 14 ++++----- docs/source/conf.py | 1 - docs/source/general/finetuning.rst | 1 + docs/source/general/registry.rst | 2 +- .../source/reference/image_classification.rst | 1 - docs/source/template/task.rst | 1 - flash/__init__.py | 1 - flash/audio/classification/data.py | 8 ++--- flash/audio/speech_recognition/data.py | 15 +++++---- flash/core/adapter.py | 8 ++--- flash/core/classification.py | 13 +++----- flash/core/data/data_module.py | 5 ++- flash/core/data/io/input.py | 26 ++++++---------- flash/core/data/io/input_transform.py | 27 +++------------- flash/core/data/io/output.py | 4 +-- flash/core/data/io/output_transform.py | 4 +-- flash/core/data/transforms.py | 6 ++-- flash/core/data/utilities/classification.py | 4 +-- flash/core/data/utilities/collate.py | 6 ++-- flash/core/data/utilities/paths.py | 4 +-- flash/core/data/utilities/sort.py | 4 +-- flash/core/finetuning.py | 7 ++--- flash/core/integrations/labelstudio/input.py | 2 -- .../integrations/pytorch_tabular/adapter.py | 1 - flash/core/model.py | 12 +++---- flash/core/optimizers/lr_scheduler.py | 4 +-- flash/core/registry.py | 7 ++--- flash/core/regression.py | 1 - flash/core/serve/composition.py | 1 - flash/core/serve/dag/rewrite.py | 4 +-- flash/core/serve/dag/task.py | 3 +- flash/core/trainer.py | 11 +++---- flash/core/utilities/lightning_cli.py | 10 +++--- flash/core/utilities/stability.py | 6 ++-- flash/graph/collate.py | 4 +-- flash/graph/embedding/model.py | 4 +-- flash/image/classification/adapters.py | 7 ++--- .../image/classification/backbones/resnet.py | 2 -- flash/image/classification/backbones/timm.py | 1 - flash/image/classification/data.py | 19 ++++++------ flash/image/detection/backbones.py | 1 - flash/image/detection/data.py | 31 +++++++++---------- flash/image/embedding/model.py | 7 ++--- flash/image/embedding/vissl/adapter.py | 1 - .../embedding/vissl/transforms/multicrop.py | 3 +- flash/image/face_detection/data.py | 3 -- flash/image/instance_segmentation/data.py | 12 +++---- flash/image/keypoint_detection/data.py | 1 - flash/image/segmentation/backbones.py | 1 - flash/image/segmentation/data.py | 12 +++---- flash/image/segmentation/heads.py | 1 - flash/image/segmentation/output.py | 8 ++--- flash/image/style_transfer/backbones.py | 1 - flash/image/style_transfer/data.py | 7 ++--- flash/pointcloud/detection/data.py | 3 -- flash/pointcloud/detection/model.py | 1 - flash/pointcloud/detection/open3d_ml/app.py | 1 - .../detection/open3d_ml/backbones.py | 1 - flash/pointcloud/detection/open3d_ml/input.py | 2 -- flash/pointcloud/segmentation/data.py | 4 --- .../pointcloud/segmentation/open3d_ml/app.py | 2 -- .../open3d_ml/sequences_dataset.py | 2 -- flash/tabular/classification/data.py | 12 +++---- flash/tabular/forecasting/data.py | 3 +- flash/tabular/forecasting/model.py | 1 - flash/tabular/input.py | 1 - flash/tabular/regression/data.py | 10 +++--- flash/template/classification/data.py | 11 +++---- flash/template/classification/model.py | 8 ++--- .../classification/backbones/huggingface.py | 1 - flash/text/classification/collate.py | 1 - flash/text/classification/data.py | 8 ++--- flash/text/embedding/model.py | 4 +-- flash/text/question_answering/data.py | 4 +-- flash/text/question_answering/model.py | 5 ++- flash/text/seq2seq/summarization/data.py | 16 +++++----- flash/text/seq2seq/translation/data.py | 20 ++++++------ flash/video/classification/data.py | 11 +++---- flash/video/classification/input_transform.py | 1 - .../image_classification_imagenette_mini.py | 1 - tests/core/data/test_base_viz.py | 2 -- .../serve/test_compat/test_cached_property.py | 1 - tests/core/serve/test_integration.py | 1 - tests/core/test_model.py | 1 - tests/graph/classification/test_model.py | 1 - tests/graph/embedding/test_model.py | 1 - tests/helpers/boring_model.py | 1 - tests/helpers/task_tester.py | 8 ++--- tests/image/classification/test_data.py | 1 - tests/image/detection/test_model.py | 1 - .../image/instance_segmentation/test_model.py | 1 - tests/pointcloud/detection/test_model.py | 1 - tests/pointcloud/segmentation/test_model.py | 1 - tests/tabular/classification/test_model.py | 1 - tests/tabular/forecasting/test_data.py | 8 ++--- tests/tabular/regression/test_model.py | 1 - tests/text/classification/test_model.py | 1 - tests/text/question_answering/test_model.py | 1 - .../text/seq2seq/summarization/test_model.py | 1 - tests/text/seq2seq/translation/test_model.py | 1 - tests/video/classification/test_model.py | 5 --- 101 files changed, 202 insertions(+), 322 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 246143ca74..e29597b17f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,15 +38,14 @@ repos: rev: v3.3.1 hooks: - id: pyupgrade - args: [--py36-plus] + args: [--py37-plus] name: Upgrade code - repo: https://github.com/PyCQA/isort - rev: 5.11.4 + rev: 5.12.0 hooks: - id: isort name: imports - require_serial: false - repo: https://github.com/kynan/nbstripout rev: 0.6.1 @@ -54,23 +53,22 @@ repos: - id: nbstripout - repo: https://github.com/PyCQA/docformatter - rev: v1.5.1 + rev: v1.5.0 hooks: - id: docformatter - args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120] + args: [--in-place, --wrap-summaries=120, --wrap-descriptions=120] - repo: https://github.com/psf/black - rev: 22.12.0 + rev: 23.1.0 hooks: - id: black name: Format code - repo: https://github.com/asottile/blacken-docs - rev: v1.12.1 + rev: 1.13.0 hooks: - id: blacken-docs args: [ --line-length=120, --skip-errors ] - additional_dependencies: [ black==21.10b0 ] - repo: https://github.com/PyCQA/flake8 rev: 6.0.0 diff --git a/docs/source/conf.py b/docs/source/conf.py index e031db8430..d574567d77 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -47,7 +47,6 @@ def _load_py_module(fname, pkg="flash"): from flash.core.utilities import providers except ModuleNotFoundError: - about = _load_py_module("__about__.py") providers = _load_py_module("core/utilities/providers.py") diff --git a/docs/source/general/finetuning.rst b/docs/source/general/finetuning.rst index f71ab35bdc..0854740d77 100644 --- a/docs/source/general/finetuning.rst +++ b/docs/source/general/finetuning.rst @@ -216,6 +216,7 @@ For even more customization, create your own finetuning callback. Learn more abo from flash.core.finetuning import FlashBaseFinetuning + # Create a finetuning callback class FeatureExtractorFreezeUnfreeze(FlashBaseFinetuning): def __init__(self, unfreeze_epoch: int = 5, train_bn: bool = True): diff --git a/docs/source/general/registry.rst b/docs/source/general/registry.rst index 0cf78aa552..356071614d 100644 --- a/docs/source/general/registry.rst +++ b/docs/source/general/registry.rst @@ -37,7 +37,6 @@ It is good practice to associate one or multiple registry to a Task as follow: # creating a custom `Task` with its own registry class MyImageClassifier(Task): - backbones = FlashRegistry("backbones") def __init__( @@ -67,6 +66,7 @@ Your custom functions can be registered within a :class:`~flash.core.registry.Fl # HINT 1: Use `from functools import partial` if you want to store some arguments. MyImageClassifier.backbones(fn=partial(fn, backbone="my_backbone"), name="username/partial_backbone") + # Option 2: Using decorator. @MyImageClassifier.backbones(name="username/decorated_backbone") def fn(pretrained: bool = True): diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index f2e9667675..20a830c8dd 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -115,7 +115,6 @@ Here's an example: @dataclass class ImageClassificationInputTransform(InputTransform): - image_size: Tuple[int, int] = (196, 196) mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406) std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225) diff --git a/docs/source/template/task.rst b/docs/source/template/task.rst index 1b83370718..6dc1496d70 100644 --- a/docs/source/template/task.rst +++ b/docs/source/template/task.rst @@ -16,7 +16,6 @@ You should attach your backbones registry as a class attribute like this: .. code-block:: python class TemplateSKLearnClassifier(ClassificationTask): - backbones: FlashRegistry = TEMPLATE_BACKBONES Model architecture and hyper-parameters diff --git a/flash/__init__.py b/flash/__init__.py index 46f83864d6..25c88b2f8c 100644 --- a/flash/__init__.py +++ b/flash/__init__.py @@ -18,7 +18,6 @@ from flash.core.utilities.imports import _TORCH_AVAILABLE if _TORCH_AVAILABLE: - from flash.core.data.callback import FlashCallback from flash.core.data.data_module import DataModule from flash.core.data.io.input import DataKeys, Input diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py index 7b802475e7..eb21fce711 100644 --- a/flash/audio/classification/data.py +++ b/flash/audio/classification/data.py @@ -312,8 +312,8 @@ def from_numpy( target_formatter: Optional[TargetFormatter] = None, **data_module_kwargs: Any, ) -> "AudioClassificationData": - """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from numpy arrays (or lists - of arrays) and corresponding lists of targets. + """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from numpy arrays (or lists of + arrays) and corresponding lists of targets. The targets can be in any of our :ref:`supported classification target formats `. @@ -400,8 +400,8 @@ def from_tensors( target_formatter: Optional[TargetFormatter] = None, **data_module_kwargs: Any, ) -> "AudioClassificationData": - """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from torch tensors (or lists - of tensors) and corresponding lists of targets. + """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from torch tensors (or lists of + tensors) and corresponding lists of targets. The targets can be in any of our :ref:`supported classification target formats `. diff --git a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py index e7c64d0c03..4b79bcafa5 100644 --- a/flash/audio/speech_recognition/data.py +++ b/flash/audio/speech_recognition/data.py @@ -56,8 +56,8 @@ def from_files( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SpeechRecognitionData": - """Load the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData` from lists of audio files - and corresponding lists of targets. + """Load the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData` from lists of audio files and + corresponding lists of targets. The supported file extensions are: ``.aiff``, ``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, ``.mat5``, ``.mpc2k``, ``.ogg``, ``.paf``, ``.pvf``, ``.rf64``, ``.ircam``, ``.voc``, ``.w64``, @@ -148,8 +148,8 @@ def from_csv( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SpeechRecognitionData": - """Load the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData` from CSV files containing - audio file paths and their corresponding targets. + """Load the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData` from CSV files containing audio + file paths and their corresponding targets. Input audio file paths will be extracted from the ``input_field`` column in the CSV files. The supported file extensions are: ``.aiff``, ``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, @@ -337,8 +337,8 @@ def from_json( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SpeechRecognitionData": - """Load the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData` from JSON files containing - audio file paths and their corresponding targets. + """Load the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData` from JSON files containing audio + file paths and their corresponding targets. Input audio file paths will be extracted from the ``input_field`` field in the JSON files. The supported file extensions are: ``.aiff``, ``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, @@ -459,8 +459,7 @@ def from_datasets( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SpeechRecognitionData": - """Load the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData` from PyTorch Dataset - objects. + """Load the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData` from PyTorch Dataset objects. The Dataset objects should be one of the following: diff --git a/flash/core/adapter.py b/flash/core/adapter.py index 12748e4b1e..d9c5328739 100644 --- a/flash/core/adapter.py +++ b/flash/core/adapter.py @@ -26,8 +26,8 @@ class Adapter(DatasetProcessor, ModuleWrapperBase, nn.Module): - """The ``Adapter`` is a lightweight interface that can be used to encapsulate the logic from a particular - provider within a :class:`~flash.core.model.Task`.""" + """The ``Adapter`` is a lightweight interface that can be used to encapsulate the logic from a particular provider + within a :class:`~flash.core.model.Task`.""" @classmethod @abstractmethod @@ -67,8 +67,8 @@ def identity_collate_fn(x): class AdapterTask(Task): - """The ``AdapterTask`` is a :class:`~flash.core.model.Task` which wraps an :class:`~flash.core.adapter.Adapter` - and forwards all of the hooks. + """The ``AdapterTask`` is a :class:`~flash.core.model.Task` which wraps an :class:`~flash.core.adapter.Adapter` and + forwards all of the hooks. Args: adapter: The :class:`~flash.core.adapter.Adapter` to wrap. diff --git a/flash/core/classification.py b/flash/core/classification.py index 5d94829945..a236e1486d 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -90,7 +90,6 @@ def __init__( labels: Optional[List[str]] = None, **kwargs, ) -> None: - metrics, loss_fn = self._build(num_classes, labels, loss_fn, metrics, multi_label) super().__init__( @@ -114,7 +113,6 @@ def __init__( labels: Optional[List[str]] = None, **kwargs, ) -> None: - metrics, loss_fn = self._build(num_classes, labels, loss_fn, metrics, multi_label) super().__init__( @@ -170,8 +168,7 @@ def transform(self, sample: Any) -> Any: @CLASSIFICATION_OUTPUTS(name="probabilities") class ProbabilitiesOutput(PredsClassificationOutput): - """A :class:`.Output` which applies a softmax to the model outputs (assumed to be logits) and converts to a - list.""" + """A :class:`.Output` which applies a softmax to the model outputs (assumed to be logits) and converts to a list.""" def transform(self, sample: Any) -> Any: sample = super().transform(sample) @@ -182,8 +179,8 @@ def transform(self, sample: Any) -> Any: @CLASSIFICATION_OUTPUTS(name="classes") class ClassesOutput(PredsClassificationOutput): - """A :class:`.Output` which applies an argmax to the model outputs (either logits or probabilities) and - converts to a list. + """A :class:`.Output` which applies an argmax to the model outputs (either logits or probabilities) and converts to + a list. Args: multi_label: If true, treats outputs as multi label logits. @@ -209,8 +206,8 @@ def transform(self, sample: Any) -> Union[int, List[int]]: @CLASSIFICATION_OUTPUTS(name="labels") class LabelsOutput(ClassesOutput): - """A :class:`.Output` which converts the model outputs (either logits or probabilities) to the label of the - argmax classification. + """A :class:`.Output` which converts the model outputs (either logits or probabilities) to the label of the argmax + classification. Args: labels: A list of labels, assumed to map the class index to the label for that class. diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index b10234da61..ccc944d142 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -124,7 +124,6 @@ def __init__( pin_memory: bool = True, persistent_workers: bool = False, ) -> None: - if not batch_size: raise TypeError("The `batch_size` should be provided to the DataModule on instantiation.") @@ -561,8 +560,8 @@ def _split_train_val( train_dataset: Dataset, val_split: float, ) -> Tuple[Any, Any]: - """Utility function for splitting the training dataset into a disjoint subset of training samples and - validation samples. + """Utility function for splitting the training dataset into a disjoint subset of training samples and validation + samples. Args: train_dataset: A instance of a :class:`torch.utils.data.Dataset`. diff --git a/flash/core/data/io/input.py b/flash/core/data/io/input.py index 8c1e669bea..0a35195792 100644 --- a/flash/core/data/io/input.py +++ b/flash/core/data/io/input.py @@ -13,7 +13,6 @@ # limitations under the License. import functools import os -import sys from enum import Enum from typing import Any, cast, Dict, Iterable, List, Sequence, Tuple, Union @@ -24,10 +23,7 @@ from flash.core.data.utils import _STAGES_PREFIX from flash.core.utilities.stages import RunningStage -if sys.version_info < (3, 7): - from typing import GenericMeta -else: - GenericMeta = type +GenericMeta = type if not os.environ.get("READTHEDOCS", False): @@ -69,8 +65,7 @@ def __hash__(self) -> int: class DataKeys(LightningEnum): - """The ``DataKeys`` enum contains the keys that are used by built-in data sources to refer to inputs and - targets.""" + """The ``DataKeys`` enum contains the keys that are used by built-in data sources to refer to inputs and targets.""" INPUT = "input" PREDS = "preds" @@ -103,8 +98,8 @@ def _has_len(data: Union[Sequence, Iterable]) -> bool: def _validate_input(input: "InputBase") -> None: - """Helper function to validate that the type of an ``InputBase.data`` is appropriate for the type of - ``InputBase`` being used. + """Helper function to validate that the type of an ``InputBase.data`` is appropriate for the type of ``InputBase`` + being used. Args: input: The ``InputBase`` instance to validate. @@ -191,9 +186,9 @@ def _call_load_sample(self, sample: Any) -> Any: @staticmethod def load_data(*args: Any, **kwargs: Any) -> Union[Sequence, Iterable]: - """The ``load_data`` hook should return a collection of samples. To reduce the memory footprint, these - samples should typically not have been loaded. For example, an input which loads images from disk would - only return the list of filenames here rather than the loaded images. + """The ``load_data`` hook should return a collection of samples. To reduce the memory footprint, these samples + should typically not have been loaded. For example, an input which loads images from disk would only return the + list of filenames here rather than the loaded images. Args: *args: Any arguments that the input requires. @@ -239,8 +234,8 @@ def predict_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterab @staticmethod def load_sample(sample: Dict[str, Any]) -> Any: - """The ``load_sample`` hook is called for each ``__getitem__`` or ``__next__`` call to the dataset with a - single sample from the output of the ``load_data`` hook as input. + """The ``load_sample`` hook is called for each ``__getitem__`` or ``__next__`` call to the dataset with a single + sample from the output of the ``load_data`` hook as input. Args: sample: A single sample from the output of the ``load_data`` hook. @@ -272,8 +267,7 @@ def test_load_sample(self, sample: Dict[str, Any]) -> Any: return self.load_sample(sample) def predict_load_sample(self, sample: Dict[str, Any]) -> Any: - """Override the ``predict_load_sample`` hook with data loading logic that is only required during - predicting. + """Override the ``predict_load_sample`` hook with data loading logic that is only required during predicting. Args: sample: A single sample from the output of the ``load_data`` hook. diff --git a/flash/core/data/io/input_transform.py b/flash/core/data/io/input_transform.py index a8694bc5b7..f78d1c1c1b 100644 --- a/flash/core/data/io/input_transform.py +++ b/flash/core/data/io/input_transform.py @@ -83,7 +83,6 @@ def per_sample_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_sample_transform(self) -> Callable: - return ApplyToKeys("input", my_func) """ pass @@ -121,7 +120,6 @@ def val_per_sample_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_sample_transform(self) -> Callable: - return ApplyToKeys("input", my_func) """ return self.per_sample_transform() @@ -159,7 +157,6 @@ def predict_per_sample_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_sample_transform(self) -> Callable: - return ApplyToKeys("input", my_func) """ return self.per_sample_transform() @@ -184,7 +181,6 @@ def serve_per_sample_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_sample_transform(self) -> Callable: - return ApplyToKeys("input", my_func) """ return self.per_sample_transform() @@ -213,7 +209,6 @@ def per_sample_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def per_sample_transform_on_device(self) -> Callable: - return ApplyToKeys("input", my_func) """ pass @@ -251,7 +246,6 @@ def val_per_sample_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def per_sample_transform_on_device(self) -> Callable: - return ApplyToKeys("input", my_func) """ return self.per_sample_transform_on_device() @@ -289,7 +283,6 @@ def predict_per_sample_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def per_sample_transform_on_device(self) -> Callable: - return ApplyToKeys("input", my_func) """ return self.per_sample_transform_on_device() @@ -314,7 +307,6 @@ def serve_per_sample_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def serve_per_sample_transform_on_device(self) -> Callable: - return ApplyToKeys("input", my_func) """ return self.per_sample_transform_on_device() @@ -343,7 +335,6 @@ def per_batch_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_batch_transform(self) -> Callable: - return ApplyToKeys("input", my_func) """ pass @@ -381,7 +372,6 @@ def val_per_batch_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_batch_transform(self) -> Callable: - return ApplyToKeys("input", my_func) """ return self.per_batch_transform() @@ -419,7 +409,6 @@ def predict_per_batch_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_batch_transform(self) -> Callable: - return ApplyToKeys("input", my_func) """ return self.per_batch_transform() @@ -444,7 +433,6 @@ def serve_per_batch_transform(self) -> Callable: class MyInputTransform(InputTransform): def per_batch_transform(self) -> Callable: - return ApplyToKeys("input", my_func) """ return self.per_batch_transform() @@ -473,7 +461,6 @@ def per_batch_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def per_batch_transform_on_device(self) -> Callable: - return ApplyToKeys("input", my_func) """ pass @@ -511,7 +498,6 @@ def val_per_batch_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def per_batch_transform_on_device(self) -> Callable: - return ApplyToKeys("input", my_func) """ return self.per_batch_transform_on_device() @@ -549,7 +535,6 @@ def predict_per_batch_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def per_batch_transform_on_device(self) -> Callable: - return ApplyToKeys("input", my_func) """ return self.per_batch_transform_on_device() @@ -574,7 +559,6 @@ def serve_per_batch_transform_on_device(self) -> Callable: class MyInputTransform(InputTransform): def serve_per_batch_transform_on_device(self) -> Callable: - return ApplyToKeys("input", my_func) """ return self.per_batch_transform_on_device() @@ -677,7 +661,6 @@ def __resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str # iterate over all transforms hook name for transform_name in InputTransformPlacement: - transform_name = transform_name.value method_name = f"{stage}_{transform_name}" @@ -726,7 +709,6 @@ def create_or_configure_input_transform( transform: INPUT_TRANSFORM_TYPE, transform_kwargs: Optional[Dict] = None, ) -> Optional[InputTransform]: - if not transform_kwargs: transform_kwargs = {} @@ -831,8 +813,8 @@ def __str__(self) -> str: def __make_collates(input_transform: InputTransform, on_device: bool, collate: Callable) -> Tuple[Callable, Callable]: - """Returns the appropriate collate functions based on whether the transforms happen in a DataLoader worker or - on the device (main process).""" + """Returns the appropriate collate functions based on whether the transforms happen in a DataLoader worker or on the + device (main process).""" if on_device: return input_transform._identity, collate return collate, input_transform._identity @@ -841,7 +823,6 @@ def __make_collates(input_transform: InputTransform, on_device: bool, collate: C def __configure_worker_and_device_collate_fn( running_stage: RunningStage, input_transform: InputTransform ) -> Tuple[Callable, Callable]: - transform_for_stage: _InputTransformPerStage = input_transform._transform[running_stage] worker_collate_fn, device_collate_fn = __make_collates( @@ -854,8 +835,8 @@ def __configure_worker_and_device_collate_fn( def create_worker_input_transform_processor( running_stage: RunningStage, input_transform: InputTransform ) -> _InputTransformProcessor: - """This utility is used to create the 2 `_InputTransformProcessor` objects which contain the transforms used as - the DataLoader `collate_fn`.""" + """This utility is used to create the 2 `_InputTransformProcessor` objects which contain the transforms used as the + DataLoader `collate_fn`.""" worker_collate_fn, _ = __configure_worker_and_device_collate_fn( running_stage=running_stage, input_transform=input_transform ) diff --git a/flash/core/data/io/output.py b/flash/core/data/io/output.py index 8802d125f8..0b8e7467a8 100644 --- a/flash/core/data/io/output.py +++ b/flash/core/data/io/output.py @@ -19,8 +19,8 @@ class Output(Properties): - """An :class:`.Output` encapsulates a single :meth:`~flash.core.data.io.output.Output.transform` method which - is used to convert the model output into the desired output format when predicting.""" + """An :class:`.Output` encapsulates a single :meth:`~flash.core.data.io.output.Output.transform` method which is + used to convert the model output into the desired output format when predicting.""" @classmethod @abstractmethod diff --git a/flash/core/data/io/output_transform.py b/flash/core/data/io/output_transform.py index 6e9e27dbe6..0e691ce51a 100644 --- a/flash/core/data/io/output_transform.py +++ b/flash/core/data/io/output_transform.py @@ -17,8 +17,8 @@ class OutputTransform: - """The :class:`~flash.core.data.io.output_transform.OutputTransform` encapsulates all the data processing logic - that should run after the model.""" + """The :class:`~flash.core.data.io.output_transform.OutputTransform` encapsulates all the data processing logic that + should run after the model.""" @staticmethod def per_batch_transform(batch: Any) -> Any: diff --git a/flash/core/data/transforms.py b/flash/core/data/transforms.py index 2f77731cbf..b7696f6433 100644 --- a/flash/core/data/transforms.py +++ b/flash/core/data/transforms.py @@ -59,9 +59,9 @@ def forward(self, x: Any) -> Any: class ApplyToKeys(nn.Sequential): - """The ``ApplyToKeys`` class is an ``nn.Sequential`` which applies the given transforms to the given keys from - the input. When a single key is given, a single value will be passed to the transforms. When multiple keys are - given, the corresponding values will be passed to the transforms as a list. + """The ``ApplyToKeys`` class is an ``nn.Sequential`` which applies the given transforms to the given keys from the + input. When a single key is given, a single value will be passed to the transforms. When multiple keys are given, + the corresponding values will be passed to the transforms as a list. Args: keys: The key (``str``) or sequence of keys (``Sequence[str]``) to extract and forward to the transforms. diff --git a/flash/core/data/utilities/classification.py b/flash/core/data/utilities/classification.py index 7675130b63..312c911031 100644 --- a/flash/core/data/utilities/classification.py +++ b/flash/core/data/utilities/classification.py @@ -385,8 +385,8 @@ def _get_target_formatter_type(target: Any) -> Type[TargetFormatter]: def _resolve_target_formatter(a: Type[TargetFormatter], b: Type[TargetFormatter]) -> Type[TargetFormatter]: """The purpose of this resolution function is to enable reduction of the ``TargetFormatter`` type over multiple - targets. For example, if one target formatter type is ``CommaDelimitedMultiLabelTargetFormatter`` and the other - type is ``SingleLabelTargetFormatter``then their reduction will be ``CommaDelimitedMultiLabelTargetFormatter``. + targets. For example, if one target formatter type is ``CommaDelimitedMultiLabelTargetFormatter`` and the other type + is ``SingleLabelTargetFormatter``then their reduction will be ``CommaDelimitedMultiLabelTargetFormatter``. Raises: ValueError: If the two target formatters could not be resolved. diff --git a/flash/core/data/utilities/collate.py b/flash/core/data/utilities/collate.py index 02c1075167..54f4a75b72 100644 --- a/flash/core/data/utilities/collate.py +++ b/flash/core/data/utilities/collate.py @@ -52,9 +52,9 @@ def wrap_collate(collate): def default_collate(batch: List[Any]) -> Any: - """The :func:`flash.data.utilities.collate.default_collate` extends `torch.utils.data._utils.default_collate` - to first extract any metadata from the samples in the batch (in the ``"metadata"`` key). The list of metadata - entries will then be inserted into the collated result. + """The :func:`flash.data.utilities.collate.default_collate` extends `torch.utils.data._utils.default_collate` to + first extract any metadata from the samples in the batch (in the ``"metadata"`` key). The list of metadata entries + will then be inserted into the collated result. Args: batch: The list of samples to collate. diff --git a/flash/core/data/utilities/paths.py b/flash/core/data/utilities/paths.py index 7598604552..472903bae0 100644 --- a/flash/core/data/utilities/paths.py +++ b/flash/core/data/utilities/paths.py @@ -136,8 +136,8 @@ def filter_valid_files( *additional_lists: List[Any], valid_extensions: Optional[Tuple[str, ...]] = None, ) -> Union[List[Any], Tuple[List[Any], ...]]: - """Filter the given list of files and any additional lists to include only the entries that contain a file with - a valid extension. + """Filter the given list of files and any additional lists to include only the entries that contain a file with a + valid extension. Args: files: The list of files to filter by. diff --git a/flash/core/data/utilities/sort.py b/flash/core/data/utilities/sort.py index 36817c6531..521b2550f5 100644 --- a/flash/core/data/utilities/sort.py +++ b/flash/core/data/utilities/sort.py @@ -24,8 +24,8 @@ def _alphanumeric_key(key: str) -> List[Union[int, str]]: def sorted_alphanumeric(iterable: Iterable[str]) -> Iterable[str]: - """Sort the given iterable in the way that humans expect. For example, given ``{"class_1", "class_11", - "class_2"}`` this returns ``["class_1", "class_2", "class_11"]``. + """Sort the given iterable in the way that humans expect. For example, given ``{"class_1", "class_11", "class_2"}`` + this returns ``["class_1", "class_2", "class_11"]``. Copied from: https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/ """ diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index d2adbb02a6..71b684de4b 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -32,8 +32,8 @@ class FinetuningStrategies(LightningEnum): - """The ``FinetuningStrategies`` enum contains the keys that are used internally by the ``FlashBaseFinetuning`` - when choosing the strategy to perform.""" + """The ``FinetuningStrategies`` enum contains the keys that are used internally by the ``FlashBaseFinetuning`` when + choosing the strategy to perform.""" NO_FREEZE = "no_freeze" FREEZE = "freeze" @@ -217,8 +217,7 @@ def __init__( class FlashDeepSpeedFinetuning(FlashBaseFinetuning): - """FlashDeepSpeedFinetuning can be used to create a custom Flash Finetuning Callback which works with - DeepSpeed. + """FlashDeepSpeedFinetuning can be used to create a custom Flash Finetuning Callback which works with DeepSpeed. DeepSpeed cannot store and load its parameters when working with Lightning. So FlashDeepSpeedFinetuning overrides `_store` to not store its parameters. diff --git a/flash/core/integrations/labelstudio/input.py b/flash/core/integrations/labelstudio/input.py index 43b7e16667..6d2ee49374 100644 --- a/flash/core/integrations/labelstudio/input.py +++ b/flash/core/integrations/labelstudio/input.py @@ -309,7 +309,6 @@ def load_data( def convert_to_encodedvideo(self, dataset): """Converting dataset to EncodedVideoDataset.""" if len(dataset) > 0: - from pytorchvideo.data import LabeledVideoDataset dataset = LabeledVideoDataset( @@ -342,7 +341,6 @@ def _parse_labelstudio_arguments( val_split: Optional[float] = None, multi_label: Optional[bool] = False, ): - train_data = None val_data = None test_data = None diff --git a/flash/core/integrations/pytorch_tabular/adapter.py b/flash/core/integrations/pytorch_tabular/adapter.py index 597e4cbd2b..21b95b6552 100644 --- a/flash/core/integrations/pytorch_tabular/adapter.py +++ b/flash/core/integrations/pytorch_tabular/adapter.py @@ -42,7 +42,6 @@ def from_task( metrics: Optional[Union[torchmetrics.Metric, List[torchmetrics.Metric]]], backbone_kwargs: Optional[Dict[str, Any]] = None, ) -> Adapter: - backbone_kwargs = backbone_kwargs or {} parameters = { "embedding_dims": embedding_sizes, diff --git a/flash/core/model.py b/flash/core/model.py index 6eb385a8e0..1eda6705a3 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -92,8 +92,8 @@ def __setattr__(self, key, value): class DatasetProcessor: - """The ``DatasetProcessor`` mixin provides hooks for classes which need custom logic for producing the data - loaders for each running stage given the corresponding dataset.""" + """The ``DatasetProcessor`` mixin provides hooks for classes which need custom logic for producing the data loaders + for each running stage given the corresponding dataset.""" def __init__(self): super().__init__() @@ -477,8 +477,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A return self(batch) def modules_to_freeze(self) -> Optional[nn.Module]: - """By default, we try to get the ``backbone`` attribute from the task and return it or ``None`` if not - present. + """By default, we try to get the ``backbone`` attribute from the task and return it or ``None`` if not present. Returns: The backbone ``Module`` to freeze or ``None`` if this task does not have a ``backbone`` attribute. @@ -543,7 +542,6 @@ def configure_finetune_callback( strategy: Union[str, Tuple[str, int], Tuple[str, Tuple[Tuple[int, int], int]], BaseFinetuning] = "no_freeze", train_bn: bool = True, ) -> List[BaseFinetuning]: - if isinstance(strategy, BaseFinetuning): return [strategy] @@ -573,8 +571,8 @@ def configure_finetune_callback( return [finetuning_strategy_fn(**finetuning_strategy_metadata)] def as_embedder(self, layer: str): - """Convert this task to an embedder. Note that the parameters are not copied so that any optimization of - the embedder will also apply to the converted ``Task``. + """Convert this task to an embedder. Note that the parameters are not copied so that any optimization of the + embedder will also apply to the converted ``Task``. Args: layer: The layer to embed to. This should be one of the :meth:`~flash.core.model.Task.available_layers`. diff --git a/flash/core/optimizers/lr_scheduler.py b/flash/core/optimizers/lr_scheduler.py index 9adb484948..497d1d087f 100644 --- a/flash/core/optimizers/lr_scheduler.py +++ b/flash/core/optimizers/lr_scheduler.py @@ -31,8 +31,8 @@ class LinearWarmupCosineAnnealingLR(_LRScheduler): - """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr - and base_lr followed by a cosine annealing schedule between base_lr and eta_min. + """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr and + base_lr followed by a cosine annealing schedule between base_lr and eta_min. .. warning:: It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR` diff --git a/flash/core/registry.py b/flash/core/registry.py index 45d76c0e25..b7f921cc14 100644 --- a/flash/core/registry.py +++ b/flash/core/registry.py @@ -186,8 +186,7 @@ def available_keys(self) -> List[str]: class ExternalRegistry(FlashRegistry): - """The ``ExternalRegistry`` is a ``FlashRegistry`` that can point to an external provider via a getter - function. + """The ``ExternalRegistry`` is a ``FlashRegistry`` that can point to an external provider via a getter function. Args: getter: A function whose first argument is a key that can optionally take additional args and kwargs. @@ -213,8 +212,8 @@ def __init__( self.metadata = metadata def __contains__(self, item): - """Contains is always ``True`` for an ``ExternalRegistry`` as we can't know whether the getter will fail - without executing it.""" + """Contains is always ``True`` for an ``ExternalRegistry`` as we can't know whether the getter will fail without + executing it.""" return True def get( diff --git a/flash/core/regression.py b/flash/core/regression.py index 85c089f72b..7d3c500ef8 100644 --- a/flash/core/regression.py +++ b/flash/core/regression.py @@ -44,7 +44,6 @@ def __init__( metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, **kwargs, ) -> None: - metrics, loss_fn = RegressionMixin._build(loss_fn, metrics) super().__init__( diff --git a/flash/core/serve/composition.py b/flash/core/serve/composition.py index c1f84d4492..213e0e18de 100644 --- a/flash/core/serve/composition.py +++ b/flash/core/serve/composition.py @@ -17,7 +17,6 @@ def _parse_composition_kwargs( **kwargs: Union[ModelComponent, Endpoint] ) -> Tuple[Dict[str, ModelComponent], Dict[str, Endpoint]]: - components, endpoints = {}, {} for k, v in kwargs.items(): if isinstance(v, ModelComponent): diff --git a/flash/core/serve/dag/rewrite.py b/flash/core/serve/dag/rewrite.py index 993a09e447..8a01bea450 100644 --- a/flash/core/serve/dag/rewrite.py +++ b/flash/core/serve/dag/rewrite.py @@ -407,8 +407,8 @@ def _match(S, N): def _process_match(rule, syms): - """Process a match to determine if it is correct, and to find the correct substitution that will convert the - term into the pattern. + """Process a match to determine if it is correct, and to find the correct substitution that will convert the term + into the pattern. Parameters ---------- diff --git a/flash/core/serve/dag/task.py b/flash/core/serve/dag/task.py index 1cdf273447..b9d6696394 100644 --- a/flash/core/serve/dag/task.py +++ b/flash/core/serve/dag/task.py @@ -381,8 +381,7 @@ def getcycle(d, keys): def isdag(d, keys): - """Does graph form a directed acyclic graph when calculating keys? ``keys`` may be a single key or list of - keys. + """Does graph form a directed acyclic graph when calculating keys? ``keys`` may be a single key or list of keys. Examples -------- diff --git a/flash/core/trainer.py b/flash/core/trainer.py index a7697f33eb..6a9ab8b94a 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -127,8 +127,8 @@ def finetune( strategy: Union[str, BaseFinetuning, Tuple[str, int], Tuple[str, Tuple[int, int]]] = "no_freeze", train_bn: bool = True, ): - r"""Runs the full optimization routine. Same as :meth:`pytorch_lightning.Trainer.fit`, but unfreezes layers - of the backbone throughout training layers of the backbone throughout training. + r"""Runs the full optimization routine. Same as :meth:`pytorch_lightning.Trainer.fit`, but unfreezes layers of + the backbone throughout training layers of the backbone throughout training. Args: model: Model to fit. @@ -219,8 +219,7 @@ def _resolve_callbacks( @staticmethod def _merge_callbacks(old_callbacks: List, new_callbacks: List) -> List: - """This function keeps only 1 instance of each callback type, extending new_callbacks with - old_callbacks.""" + """This function keeps only 1 instance of each callback type, extending new_callbacks with old_callbacks.""" if len(new_callbacks) == 0: return old_callbacks new_callbacks_types = {type(c) for c in new_callbacks} @@ -246,8 +245,8 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> @property def estimated_stepping_batches(self) -> Union[int, float]: - """Estimated stepping batches for the complete training inferred from DataLoaders, gradient accumulation - factor and distributed setup. + """Estimated stepping batches for the complete training inferred from DataLoaders, gradient accumulation factor + and distributed setup. Examples ________ diff --git a/flash/core/utilities/lightning_cli.py b/flash/core/utilities/lightning_cli.py index c55206d5bf..b9d8224115 100644 --- a/flash/core/utilities/lightning_cli.py +++ b/flash/core/utilities/lightning_cli.py @@ -247,11 +247,11 @@ def __init__( subclass_mode_model: bool = False, subclass_mode_data: bool = False, ) -> None: - """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which - are called / instantiated using a parsed configuration file and / or command line args and then runs - trainer.fit. Parsing of configuration from environment variables can be enabled by setting - ``env_parse=True``. A full configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual - settings are so parsed from variables named for example ``PL_TRAINER__MAX_EPOCHS``. + """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are + called / instantiated using a parsed configuration file and / or command line args and then runs trainer.fit. + Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``. A full + configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual settings are so parsed from variables + named for example ``PL_TRAINER__MAX_EPOCHS``. Example, first implement the ``trainer.py`` tool as:: diff --git a/flash/core/utilities/stability.py b/flash/core/utilities/stability.py index c78636d25d..45241c853c 100644 --- a/flash/core/utilities/stability.py +++ b/flash/core/utilities/stability.py @@ -35,9 +35,9 @@ def _raise_beta_warning(message: str, stacklevel: int = 6): def beta(message: str = "This feature is currently in Beta."): - """The beta decorator is used to indicate that a particular feature is in Beta. A callable or type that has - been marked as beta will give a ``UserWarning`` when it is called or instantiated. This designation should be - used following the description given in :ref:`stability`. + """The beta decorator is used to indicate that a particular feature is in Beta. A callable or type that has been + marked as beta will give a ``UserWarning`` when it is called or instantiated. This designation should be used + following the description given in :ref:`stability`. Args: message: The message to include in the warning. diff --git a/flash/graph/collate.py b/flash/graph/collate.py index e356ce9235..aba0bcab96 100644 --- a/flash/graph/collate.py +++ b/flash/graph/collate.py @@ -23,8 +23,8 @@ def _pyg_collate(samples: List[Dict[str, Any]]) -> Dict[str, Any]: - """Helper to collate PyTorch Geometric ``Data`` objects into PyTorch Geometric ``Batch`` objects whilst - preserving our dictionary format.""" + """Helper to collate PyTorch Geometric ``Data`` objects into PyTorch Geometric ``Batch`` objects whilst preserving + our dictionary format.""" inputs = Batch.from_data_list([sample[DataKeys.INPUT] for sample in samples]) if DataKeys.TARGET in samples[0]: targets = default_collate([sample[DataKeys.TARGET] for sample in samples]) diff --git a/flash/graph/embedding/model.py b/flash/graph/embedding/model.py index 1f739f931b..33054eed1c 100644 --- a/flash/graph/embedding/model.py +++ b/flash/graph/embedding/model.py @@ -23,8 +23,8 @@ class GraphEmbedder(Task): - """The ``GraphEmbedder`` is a :class:`~flash.Task` for obtaining feature vectors (embeddings) from graphs. For - more details, see :ref:`graph_embedder`. + """The ``GraphEmbedder`` is a :class:`~flash.Task` for obtaining feature vectors (embeddings) from graphs. For more + details, see :ref:`graph_embedder`. Args: backbone: A model to use to extract image features. diff --git a/flash/image/classification/adapters.py b/flash/image/classification/adapters.py index a518bc8d68..b5949a369f 100644 --- a/flash/image/classification/adapters.py +++ b/flash/image/classification/adapters.py @@ -79,7 +79,6 @@ def forward(self, x): @beta("The Learn2Learn integration is currently in Beta.") class Learn2LearnAdapter(Adapter): - required_extras: str = "image" def __init__( @@ -102,8 +101,8 @@ def __init__( seed: int = 42, **algorithm_kwargs, ): - """The ``Learn2LearnAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with `learn 2 - learn` library (https://github.com/learnables/learn2learn). + """The ``Learn2LearnAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with `learn 2 learn` + library (https://github.com/learnables/learn2learn). Args: task: Task to be used. This adapter should work with any Flash Classification task @@ -216,7 +215,6 @@ def _convert_dataset( ) if isinstance(dataset, InputBase): - metadata = getattr(dataset, "data", None) if metadata is None or (metadata is not None and not isinstance(dataset.data, list)): raise TypeError("Only dataset built out of metadata is supported.") @@ -463,7 +461,6 @@ def process_predict_dataset( input_transform: Optional[InputTransform] = None, trainer: Optional["flash.Trainer"] = None, ) -> DataLoader: - if not self._algorithm_has_validated: raise RuntimeError( "This training strategy needs to be validated before it can be used for prediction." diff --git a/flash/image/classification/backbones/resnet.py b/flash/image/classification/backbones/resnet.py index 80c70241a3..a3e1a21f73 100644 --- a/flash/image/classification/backbones/resnet.py +++ b/flash/image/classification/backbones/resnet.py @@ -175,7 +175,6 @@ def __init__( remove_first_maxpool: bool = False, in_chans: int = 3, ) -> None: - super().__init__() if norm_layer is None: @@ -315,7 +314,6 @@ def _resnet( weights_paths: dict = {"supervised": None}, **kwargs: Any, ) -> ResNet: - pretrained_flag = (pretrained and isinstance(pretrained, bool)) or (pretrained == "supervised") backbone = ResNet(block, layers, **kwargs) diff --git a/flash/image/classification/backbones/timm.py b/flash/image/classification/backbones/timm.py index ffdc71c39a..90928c964f 100644 --- a/flash/image/classification/backbones/timm.py +++ b/flash/image/classification/backbones/timm.py @@ -39,7 +39,6 @@ def _fn_timm( def register_timm_backbones(register: FlashRegistry): if _TIMM_AVAILABLE: for model_name in timm.list_models(): - if model_name in TORCHVISION_MODELS: continue diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index 4726c0d3b7..df5ec1f89f 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -193,8 +193,7 @@ def from_folders( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ImageClassificationData": - """Load the :class:`~flash.image.classification.data.ImageClassificationData` from folders containing - images. + """Load the :class:`~flash.image.classification.data.ImageClassificationData` from folders containing images. The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, ``.tiff``, ``.webp``, and ``.npy``. @@ -317,8 +316,8 @@ def from_numpy( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ImageClassificationData": - """Load the :class:`~flash.image.classification.data.ImageClassificationData` from numpy arrays (or lists - of arrays) and corresponding lists of targets. + """Load the :class:`~flash.image.classification.data.ImageClassificationData` from numpy arrays (or lists of + arrays) and corresponding lists of targets. The targets can be in any of our :ref:`supported classification target formats `. @@ -494,8 +493,8 @@ def from_tensors( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ImageClassificationData": - """Load the :class:`~flash.image.classification.data.ImageClassificationData` from torch tensors (or lists - of tensors) and corresponding lists of targets. + """Load the :class:`~flash.image.classification.data.ImageClassificationData` from torch tensors (or lists of + tensors) and corresponding lists of targets. The targets can be in any of our :ref:`supported classification target formats `. @@ -731,8 +730,8 @@ def from_csv( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ImageClassificationData": - """Load the :class:`~flash.image.classification.data.ImageClassificationData` from CSV files containing - image file paths and their corresponding targets. + """Load the :class:`~flash.image.classification.data.ImageClassificationData` from CSV files containing image + file paths and their corresponding targets. Input images will be extracted from the ``input_field`` column in the CSV files. The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, @@ -954,8 +953,8 @@ def from_fiftyone( transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "ImageClassificationData": - """Load the :class:`~flash.image.classification.data.ImageClassificationData` from FiftyOne - ``SampleCollection`` objects. + """Load the :class:`~flash.image.classification.data.ImageClassificationData` from FiftyOne ``SampleCollection`` + objects. The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, ``.tiff``, ``.webp``, and ``.npy``. diff --git a/flash/image/detection/backbones.py b/flash/image/detection/backbones.py index 6bcfd5dd44..8873c3352b 100644 --- a/flash/image/detection/backbones.py +++ b/flash/image/detection/backbones.py @@ -103,7 +103,6 @@ def from_task( ) if module_available("effdet"): - model_type = icevision_models.ross.efficientdet OBJECT_DETECTION_HEADS( partial(load_icevision_with_image_size, model_type), diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index c4ba2fb772..b618db3f7b 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -84,8 +84,8 @@ def from_files( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ObjectDetectionData": - """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data list of - image files, bounding boxes, and targets. + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data list of image + files, bounding boxes, and targets. The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, ``.tiff``, ``.webp``, and ``.npy``. @@ -217,8 +217,8 @@ def from_numpy( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ObjectDetectionData": - """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given from numpy - arrays (or lists of arrays) and corresponding lists of bounding boxes and targets. + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given from numpy arrays + (or lists of arrays) and corresponding lists of bounding boxes and targets. The targets can be in any of our :ref:`supported classification target formats `. @@ -335,8 +335,8 @@ def from_images( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ObjectDetectionData": - """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given lists of PIL - images and corresponding lists of bounding boxes and targets. + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given lists of PIL images + and corresponding lists of bounding boxes and targets. The targets can be in any of our :ref:`supported classification target formats `. @@ -458,8 +458,8 @@ def from_tensors( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ObjectDetectionData": - """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given from torch - tensors (or lists of tensors) and corresponding lists of bounding boxes and targets. + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given from torch tensors + (or lists of tensors) and corresponding lists of bounding boxes and targets. The targets can be in any of our :ref:`supported classification target formats `. @@ -576,7 +576,6 @@ def from_icedata( transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "ObjectDetectionData": - ds_kw = dict(parser=parser) return cls( @@ -925,10 +924,9 @@ def from_via( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ObjectDetectionData": - """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders - and annotation files in the VIA (`VGG Image Annotator `_) - `JSON format `_. + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders and + annotation files in the VIA (`VGG Image Annotator `_) `JSON + format `_. To learn how to customize the transforms applied for each stage, read our :ref:`customizing transforms guide `. @@ -1078,8 +1076,7 @@ def from_fiftyone( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ObjectDetectionData": - """Load the :class:`~flash.image.detection.data.ObjectDetectionData` from FiftyOne ``SampleCollection`` - objects. + """Load the :class:`~flash.image.detection.data.ObjectDetectionData` from FiftyOne ``SampleCollection`` objects. Targets will be extracted from the ``label_field`` in the ``SampleCollection`` objects. To learn how to customize the transforms applied for each stage, read our @@ -1186,8 +1183,8 @@ def from_folders( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "DataModule": - """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders - This is currently support only for the predicting stage. + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders This is + currently support only for the predicting stage. Args: predict_folder: The folder containing the predict data. diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index 6a31a22def..a494f0d567 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -48,8 +48,8 @@ class ImageEmbedder(AdapterTask): - """The ``ImageEmbedder`` is a :class:`~flash.Task` for obtaining feature vectors (embeddings) from images. For - more details, see :ref:`image_embedder`. + """The ``ImageEmbedder`` is a :class:`~flash.Task` for obtaining feature vectors (embeddings) from images. For more + details, see :ref:`image_embedder`. Args: training_strategy: Training strategy from VISSL, @@ -160,8 +160,7 @@ def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, *args) -> @classmethod @requires("image", "vissl", "fairscale") def available_training_strategies(cls) -> List[str]: - """Get the list of available training strategies (passed to the ``training_strategy`` argument) for this - task. + """Get the list of available training strategies (passed to the ``training_strategy`` argument) for this task. Examples ________ diff --git a/flash/image/embedding/vissl/adapter.py b/flash/image/embedding/vissl/adapter.py index 49140d7a3e..b059463430 100644 --- a/flash/image/embedding/vissl/adapter.py +++ b/flash/image/embedding/vissl/adapter.py @@ -87,7 +87,6 @@ def __init__( loss_fn: ClassyLoss, hooks: List[ClassyHook], ) -> None: - Adapter.__init__(self) self.model_config = self.get_model_config_template() diff --git a/flash/image/embedding/vissl/transforms/multicrop.py b/flash/image/embedding/vissl/transforms/multicrop.py index 56e39cf12f..cddd57d9f7 100644 --- a/flash/image/embedding/vissl/transforms/multicrop.py +++ b/flash/image/embedding/vissl/transforms/multicrop.py @@ -28,8 +28,7 @@ @dataclass class StandardMultiCropSSLTransform(InputTransform): - """Convert a PIL image to Multi-resolution Crops. The input is a PIL image and output is the list of image - crops. + """Convert a PIL image to Multi-resolution Crops. The input is a PIL image and output is the list of image crops. This transform was proposed in SwAV - https://arxiv.org/abs/2006.09882 This transform can act as a base transform class for SimCLR, SwAV, and Barlow Twins from VISSL. diff --git a/flash/image/face_detection/data.py b/flash/image/face_detection/data.py index 84bb292977..9bbc6e2aed 100644 --- a/flash/image/face_detection/data.py +++ b/flash/image/face_detection/data.py @@ -41,7 +41,6 @@ def from_datasets( transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "FaceDetectionData": - ds_kw = dict() return cls( @@ -63,7 +62,6 @@ def from_files( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "FaceDetectionData": - return cls( predict_input=input_cls(RunningStage.PREDICTING, predict_files), transform=predict_transform, @@ -80,7 +78,6 @@ def from_folders( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "FaceDetectionData": - return cls( predict_input=input_cls(RunningStage.PREDICTING, predict_folder), transform=predict_transform, diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index d6533ab074..e546137799 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -85,7 +85,6 @@ def from_icedata( transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "InstanceSegmentationData": - ds_kw = dict(parser=parser) return cls( @@ -131,8 +130,8 @@ def from_coco( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ): - """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the - given data folders and annotation files in the `COCO JSON format `_. + """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the given + data folders and annotation files in the `COCO JSON format `_. For help understanding and using the COCO format, take a look at this tutorial: `Create COCO annotations from scratch `__. @@ -284,9 +283,10 @@ def from_voc( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ): - """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the - given data folders, mask folders, and annotation files in the `PASCAL VOC (Visual Object Challenge) XML - format `_. + """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the given + data folders, mask folders, and annotation files in the `PASCAL VOC (Visual Object Challenge) XML format. + + `_. .. note:: All three arguments `*_folder`, `*_target_folder`, and `*_ann_folder` are needed to load data for a particular stage. diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py index 32d0395a0c..d2d422ca79 100644 --- a/flash/image/keypoint_detection/data.py +++ b/flash/image/keypoint_detection/data.py @@ -85,7 +85,6 @@ def from_icedata( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "KeypointDetectionData": - ds_kw = dict(parser=parser) return cls( diff --git a/flash/image/segmentation/backbones.py b/flash/image/segmentation/backbones.py index 0c73cc14fa..4a0a679b4d 100644 --- a/flash/image/segmentation/backbones.py +++ b/flash/image/segmentation/backbones.py @@ -23,7 +23,6 @@ SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones") if _SEGMENTATION_MODELS_AVAILABLE: - ENCODERS = smp.encoders.get_encoder_names() def _load_smp_backbone(backbone: str, **_) -> str: diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 5387a8095d..4ac4dbd774 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -181,8 +181,8 @@ def from_folders( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SemanticSegmentationData": - """Load the :class:`~flash.image.segmentation.data.SemanticSegmentationData` from folders containing image - files and folders containing mask files. + """Load the :class:`~flash.image.segmentation.data.SemanticSegmentationData` from folders containing image files + and folders containing mask files. The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, ``.tiff``, ``.webp``, and ``.npy``. @@ -322,8 +322,8 @@ def from_numpy( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SemanticSegmentationData": - """Load the :class:`~flash.image.segmentation.data.SemanticSegmentationData` from numpy arrays containing - images (or lists of arrays) and corresponding numpy arrays containing masks (or lists of arrays). + """Load the :class:`~flash.image.segmentation.data.SemanticSegmentationData` from numpy arrays containing images + (or lists of arrays) and corresponding numpy arrays containing masks (or lists of arrays). To learn how to customize the transforms applied for each stage, read our :ref:`customizing transforms guide `. @@ -496,8 +496,8 @@ def from_fiftyone( label_field: str = "ground_truth", **data_module_kwargs: Any, ) -> "SemanticSegmentationData": - """Load the :class:`~flash.image.segmentation.data.SemanticSegmentationData` from FiftyOne - ``SampleCollection`` objects. + """Load the :class:`~flash.image.segmentation.data.SemanticSegmentationData` from FiftyOne ``SampleCollection`` + objects. The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, ``.tiff``, ``.webp``, and ``.npy``. diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index 4886dade8f..33d040d05f 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -48,7 +48,6 @@ def _load_smp_head( in_channels: int = 3, **kwargs, ) -> nn.Module: - if head not in SMP_MODELS: raise NotImplementedError(f"{head} is not implemented! Supported heads -> {SMP_MODELS.keys()}") diff --git a/flash/image/segmentation/output.py b/flash/image/segmentation/output.py index cecd2c796c..522e228ed6 100644 --- a/flash/image/segmentation/output.py +++ b/flash/image/segmentation/output.py @@ -54,8 +54,8 @@ @SEMANTIC_SEGMENTATION_OUTPUTS(name="labels") class SegmentationLabelsOutput(Output): - """A :class:`.Output` which converts the model outputs to the label of the argmax classification per pixel in - the image for semantic segmentation tasks. + """A :class:`.Output` which converts the model outputs to the label of the argmax classification per pixel in the + image for semantic segmentation tasks. Args: labels_map: A dictionary that map the labels ids to pixel intensities. @@ -70,8 +70,8 @@ def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, @staticmethod def labels_to_image(img_labels: Tensor, labels_map: Dict[int, Tuple[int, int, int]]) -> Tensor: - """Function that given an image with labels ids and their pixel intensity mapping, creates an RGB - representation for visualisation purposes.""" + """Function that given an image with labels ids and their pixel intensity mapping, creates an RGB representation + for visualisation purposes.""" assert len(img_labels.shape) == 2, img_labels.shape H, W = img_labels.shape out = torch.empty(3, H, W, dtype=torch.uint8) diff --git a/flash/image/style_transfer/backbones.py b/flash/image/style_transfer/backbones.py index 07c05f1ca1..1258e1264a 100644 --- a/flash/image/style_transfer/backbones.py +++ b/flash/image/style_transfer/backbones.py @@ -22,7 +22,6 @@ __all__ = ["STYLE_TRANSFER_BACKBONES"] if _PYSTICHE_AVAILABLE: - from pystiche import enc MLE_FN_PATTERN = re.compile(r"^(?P\w+?)_multi_layer_encoder$") diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index 75349ed02d..fea37d99c5 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -33,8 +33,8 @@ @beta("Style transfer is currently in Beta.") class StyleTransferData(DataModule): - """The ``StyleTransferData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of - classmethods for loading data for image style transfer.""" + """The ``StyleTransferData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of classmethods + for loading data for image style transfer.""" input_transform_cls = StyleTransferInputTransform @@ -205,8 +205,7 @@ def from_numpy( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any ) -> "StyleTransferData": - """Load the :class:`~flash.image.style_transfer.data.StyleTransferData` from numpy arrays (or lists of - arrays). + """Load the :class:`~flash.image.style_transfer.data.StyleTransferData` from numpy arrays (or lists of arrays). To learn how to customize the transforms applied for each stage, read our :ref:`customizing transforms guide `. diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py index cb550343fa..a81049c33b 100644 --- a/flash/pointcloud/detection/data.py +++ b/flash/pointcloud/detection/data.py @@ -48,7 +48,6 @@ def from_folders( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudObjectDetectorData": - ds_kw = dict( scans_folder_name=scans_folder_name, labels_folder_name=labels_folder_name, @@ -79,7 +78,6 @@ def from_files( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudObjectDetectorData": - ds_kw = dict( scans_folder_name=scans_folder_name, labels_folder_name=labels_folder_name, @@ -106,7 +104,6 @@ def from_datasets( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudObjectDetectorData": - ds_kw = dict() return cls( diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index 596fef9595..1ceea10ad8 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -69,7 +69,6 @@ def __init__( lambda_loss_bbox: float = 1.0, lambda_loss_dir: float = 1.0, ): - super().__init__( model=None, loss_fn=loss_fn, diff --git a/flash/pointcloud/detection/open3d_ml/app.py b/flash/pointcloud/detection/open3d_ml/app.py index 8022a94e71..c549875597 100644 --- a/flash/pointcloud/detection/open3d_ml/app.py +++ b/flash/pointcloud/detection/open3d_ml/app.py @@ -20,7 +20,6 @@ from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE if _POINTCLOUD_AVAILABLE: - from open3d._ml3d.vis.visualizer import LabelLUT, Visualizer from open3d.visualization import gui diff --git a/flash/pointcloud/detection/open3d_ml/backbones.py b/flash/pointcloud/detection/open3d_ml/backbones.py index 5489a32067..ce38bad11b 100644 --- a/flash/pointcloud/detection/open3d_ml/backbones.py +++ b/flash/pointcloud/detection/open3d_ml/backbones.py @@ -51,7 +51,6 @@ def __len__(self): def register_open_3d_ml(register: FlashRegistry): if _POINTCLOUD_AVAILABLE: - CONFIG_PATH = os.path.join(os.path.dirname(open3d.__file__), "_ml3d/configs") def get_collate_fn(model) -> Callable: diff --git a/flash/pointcloud/detection/open3d_ml/input.py b/flash/pointcloud/detection/open3d_ml/input.py index beac5adac8..616fe1d211 100644 --- a/flash/pointcloud/detection/open3d_ml/input.py +++ b/flash/pointcloud/detection/open3d_ml/input.py @@ -162,7 +162,6 @@ def predict_load_sample(self, metadata: Dict[str, str]): class PointCloudObjectDetectorFoldersInput(Input): - loaders: Dict[PointCloudObjectDetectionDataFormat, Type[BasePointCloudObjectDetectorLoader]] = { PointCloudObjectDetectionDataFormat.KITTI: KITTIPointCloudObjectDetectorLoader } @@ -229,7 +228,6 @@ def predict_load_data( return self.loader.predict_load_data(data, self) def predict_load_sample(self, metadata: Dict[str, str]) -> Any: - data, metadata = self.loader.predict_load_sample(metadata) input_transform_fn = getattr(self, "input_transform_fn", None) diff --git a/flash/pointcloud/segmentation/data.py b/flash/pointcloud/segmentation/data.py index 2a6dc3bcd7..f07f1981e1 100644 --- a/flash/pointcloud/segmentation/data.py +++ b/flash/pointcloud/segmentation/data.py @@ -26,7 +26,6 @@ @beta("Point cloud segmentation is currently in Beta.") class PointCloudSegmentationData(DataModule): - input_transform_cls = InputTransform @classmethod @@ -41,7 +40,6 @@ def from_folders( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudSegmentationData": - ds_kw = dict() return cls( @@ -63,7 +61,6 @@ def from_files( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudSegmentationData": - ds_kw = dict() return cls( @@ -85,7 +82,6 @@ def from_datasets( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudSegmentationData": - ds_kw = dict() return cls( diff --git a/flash/pointcloud/segmentation/open3d_ml/app.py b/flash/pointcloud/segmentation/open3d_ml/app.py index ca20096542..929bc93121 100644 --- a/flash/pointcloud/segmentation/open3d_ml/app.py +++ b/flash/pointcloud/segmentation/open3d_ml/app.py @@ -18,13 +18,11 @@ from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE if _POINTCLOUD_AVAILABLE: - from open3d._ml3d.torch.dataloaders import TorchDataloader from open3d._ml3d.vis.visualizer import LabelLUT from open3d._ml3d.vis.visualizer import Visualizer as Open3dVisualizer else: - Open3dVisualizer = object diff --git a/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py b/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py index 31d44e612e..e7dfc80cfa 100644 --- a/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py +++ b/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py @@ -21,7 +21,6 @@ from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE if _POINTCLOUD_AVAILABLE: - from open3d._ml3d.datasets.utils import DataProcessing from open3d._ml3d.utils.config import Config @@ -44,7 +43,6 @@ def __init__( predicting=False, **kwargs, ): - super().__init__() self.name = "Dataset" diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py index d6a46e5a1e..1db88edd4e 100644 --- a/flash/tabular/classification/data.py +++ b/flash/tabular/classification/data.py @@ -57,8 +57,8 @@ def from_data_frame( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TabularClassificationData": - """Creates a :class:`~flash.tabular.classification.data.TabularClassificationData` object from the given - data frames. + """Creates a :class:`~flash.tabular.classification.data.TabularClassificationData` object from the given data + frames. .. note:: @@ -196,8 +196,8 @@ def from_csv( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TabularClassificationData": - """Creates a :class:`~flash.tabular.classification.data.TabularClassificationData` object from the given - CSV files. + """Creates a :class:`~flash.tabular.classification.data.TabularClassificationData` object from the given CSV + files. .. note:: @@ -534,8 +534,8 @@ def from_lists( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TabularClassificationData": - """Creates a :class:`~flash.tabular.classification.data.TabularClassificationData` object from the given - data (in the form of list of a tuple or a dictionary). + """Creates a :class:`~flash.tabular.classification.data.TabularClassificationData` object from the given data + (in the form of list of a tuple or a dictionary). .. note:: The ``categorical_fields``, ``numerical_fields``, and ``target_fields`` do not need to be provided if diff --git a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py index f2e4158eef..abf4094ad0 100644 --- a/flash/tabular/forecasting/data.py +++ b/flash/tabular/forecasting/data.py @@ -69,8 +69,7 @@ def from_data_frame( persistent_workers: bool = True, **input_kwargs: Any, ) -> "TabularForecastingData": - """Creates a :class:`~flash.tabular.forecasting.data.TabularForecastingData` object from the given data - frames. + """Creates a :class:`~flash.tabular.forecasting.data.TabularForecastingData` object from the given data frames. .. note:: diff --git a/flash/tabular/forecasting/model.py b/flash/tabular/forecasting/model.py index db08a7d660..1fe2144205 100644 --- a/flash/tabular/forecasting/model.py +++ b/flash/tabular/forecasting/model.py @@ -24,7 +24,6 @@ class TabularForecaster(AdapterTask): - backbones: FlashRegistry = FlashRegistry("backbones") + PYTORCH_FORECASTING_BACKBONES required_extras: str = "tabular" diff --git a/flash/tabular/input.py b/flash/tabular/input.py index e994c65edd..26e315c019 100644 --- a/flash/tabular/input.py +++ b/flash/tabular/input.py @@ -60,7 +60,6 @@ def compute_parameters( numerical_fields: List[str], categorical_fields: List[str], ) -> Dict[str, Any]: - mean, std = _compute_normalization(train_data_frame, numerical_fields) codes = _generate_codes(train_data_frame, categorical_fields) diff --git a/flash/tabular/regression/data.py b/flash/tabular/regression/data.py index c082b1c0c4..d9c714ab08 100644 --- a/flash/tabular/regression/data.py +++ b/flash/tabular/regression/data.py @@ -55,8 +55,7 @@ def from_data_frame( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TabularRegressionData": - """Creates a :class:`~flash.tabular.regression.data.TabularRegressionData` object from the given data - frames. + """Creates a :class:`~flash.tabular.regression.data.TabularRegressionData` object from the given data frames. .. note:: @@ -371,8 +370,7 @@ def from_dicts( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TabularRegressionData": - """Creates a :class:`~flash.tabular.regression.data.TabularRegressionData` object from the given - dictionary. + """Creates a :class:`~flash.tabular.regression.data.TabularRegressionData` object from the given dictionary. .. note:: @@ -498,8 +496,8 @@ def from_lists( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TabularRegressionData": - """Creates a :class:`~flash.tabular.regression.data.TabularRegressionData` object from the given data (in - the form of list of a tuple or a dictionary). + """Creates a :class:`~flash.tabular.regression.data.TabularRegressionData` object from the given data (in the + form of list of a tuple or a dictionary). .. note:: diff --git a/flash/template/classification/data.py b/flash/template/classification/data.py index 8dc81e7b01..42ec03f98c 100644 --- a/flash/template/classification/data.py +++ b/flash/template/classification/data.py @@ -136,8 +136,8 @@ def from_numpy( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TemplateData": - """This is our custom ``from_*`` method. It expects numpy ``Array`` objects and targets as input and - creates the ``TemplateData`` with them. + """This is our custom ``from_*`` method. It expects numpy ``Array`` objects and targets as input and creates the + ``TemplateData`` with them. Args: train_data: The numpy ``Array`` containing the train data. @@ -228,14 +228,13 @@ def num_features(self) -> Optional[int]: @staticmethod def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: - """We can, *optionally*, provide a data visualization callback using the ``configure_data_fetcher`` - method.""" + """We can, *optionally*, provide a data visualization callback using the ``configure_data_fetcher`` method.""" return TemplateVisualization(*args, **kwargs) class TemplateVisualization(BaseVisualization): - """The ``TemplateVisualization`` class is a :class:`~flash.core.data.callbacks.BaseVisualization` that just - prints the data. + """The ``TemplateVisualization`` class is a :class:`~flash.core.data.callbacks.BaseVisualization` that just prints + the data. If you want to provide a visualization with your task, you can override these hooks. """ diff --git a/flash/template/classification/model.py b/flash/template/classification/model.py index e25a2af354..e5d812f15f 100644 --- a/flash/template/classification/model.py +++ b/flash/template/classification/model.py @@ -23,8 +23,8 @@ class TemplateSKLearnClassifier(ClassificationTask): - """The ``TemplateSKLearnClassifier`` is a :class:`~flash.core.classification.ClassificationTask` that - classifies tabular data from scikit-learn. + """The ``TemplateSKLearnClassifier`` is a :class:`~flash.core.classification.ClassificationTask` that classifies + tabular data from scikit-learn. Args: num_features: The number of features (elements) in the input data. @@ -106,8 +106,8 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - """For the predict step, we just extract the :attr:`~flash.core.data.io.input.DataKeys.INPUT` key from the - input and forward it to the :meth:`~flash.core.model.Task.predict_step`.""" + """For the predict step, we just extract the :attr:`~flash.core.data.io.input.DataKeys.INPUT` key from the input + and forward it to the :meth:`~flash.core.model.Task.predict_step`.""" batch = batch[DataKeys.INPUT] return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) diff --git a/flash/text/classification/backbones/huggingface.py b/flash/text/classification/backbones/huggingface.py index b9381405fb..6d4971ac7e 100644 --- a/flash/text/classification/backbones/huggingface.py +++ b/flash/text/classification/backbones/huggingface.py @@ -33,7 +33,6 @@ def load_hugingface(backbone: str, num_classes: int): HUGGINGFACE_BACKBONES = FlashRegistry("backbones") if _TRANSFORMERS_AVAILABLE: - HUGGINGFACE_BACKBONES = ExternalRegistry( getter=load_hugingface, name="backbones", diff --git a/flash/text/classification/collate.py b/flash/text/classification/collate.py index bb4ec6be17..996fa425ab 100644 --- a/flash/text/classification/collate.py +++ b/flash/text/classification/collate.py @@ -19,7 +19,6 @@ @dataclass(unsafe_hash=True) class TextClassificationCollate(TransformersCollate): - max_length: int = 128 def tokenize(self, sample): diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 2261d1e833..09a7b59129 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -368,8 +368,8 @@ def from_parquet( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TextClassificationData": - """Load the :class:`~flash.text.classification.data.TextClassificationData` from PARQUET files containing - text snippets and their corresponding targets. + """Load the :class:`~flash.text.classification.data.TextClassificationData` from PARQUET files containing text + snippets and their corresponding targets. Input text snippets will be extracted from the ``input_field`` column in the PARQUET files. The targets will be extracted from the ``target_fields`` in the PARQUET files and can be in any of our @@ -593,8 +593,8 @@ def from_data_frame( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TextClassificationData": - """Load the :class:`~flash.text.classification.data.TextClassificationData` from Pandas ``DataFrame`` - objects containing text snippets and their corresponding targets. + """Load the :class:`~flash.text.classification.data.TextClassificationData` from Pandas ``DataFrame`` objects + containing text snippets and their corresponding targets. Input text snippets will be extracted from the ``input_field`` column in the ``DataFrame`` objects. The targets will be extracted from the ``target_fields`` in the ``DataFrame`` objects and can be in any of our diff --git a/flash/text/embedding/model.py b/flash/text/embedding/model.py index d66adff0e1..a170265e28 100644 --- a/flash/text/embedding/model.py +++ b/flash/text/embedding/model.py @@ -37,8 +37,8 @@ class TextEmbedder(Task): - """The ``TextEmbedder`` is a :class:`~flash.Task` for generating sentence embeddings, training and validation. - For more details, see `embeddings`. + """The ``TextEmbedder`` is a :class:`~flash.Task` for generating sentence embeddings, training and validation. For + more details, see `embeddings`. You can change the backbone to any question answering model from `UKPLab/sentence-transformers `_ using the ``backbone`` diff --git a/flash/text/question_answering/data.py b/flash/text/question_answering/data.py index f0d90202f5..06b7380166 100644 --- a/flash/text/question_answering/data.py +++ b/flash/text/question_answering/data.py @@ -666,8 +666,8 @@ def from_dicts( answer_column_name: str = "answer", **data_module_kwargs: Any, ) -> "QuestionAnsweringData": - """Load the :class:`~flash.text.question_answering.data.QuestionAnsweringData` from Python dictionary - objects containing questions, contexts and their corresponding answers. + """Load the :class:`~flash.text.question_answering.data.QuestionAnsweringData` from Python dictionary objects + containing questions, contexts and their corresponding answers. Question snippets will be extracted from the ``question_column_name`` field in the dictionaries. Context snippets will be extracted from the ``context_column_name`` field in the dictionaries. diff --git a/flash/text/question_answering/model.py b/flash/text/question_answering/model.py index 26adde0452..9405451375 100644 --- a/flash/text/question_answering/model.py +++ b/flash/text/question_answering/model.py @@ -54,8 +54,8 @@ class QuestionAnsweringTask(Task): - """The ``QuestionAnsweringTask`` is a :class:`~flash.Task` for extractive question answering. For more details, - see `question_answering`. + """The ``QuestionAnsweringTask`` is a :class:`~flash.Task` for extractive question answering. For more details, see + `question_answering`. You can change the backbone to any question answering model from `HuggingFace/transformers `_ using the ``backbone`` @@ -155,7 +155,6 @@ def __init__( ) def _generate_answers(self, pred_start_logits, pred_end_logits, examples): - all_predictions = collections.OrderedDict() if self.version_2_with_negative: scores_diff_json = collections.OrderedDict() diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py index 4f13901b0e..cb42e03f43 100644 --- a/flash/text/seq2seq/summarization/data.py +++ b/flash/text/seq2seq/summarization/data.py @@ -33,8 +33,8 @@ class SummarizationData(DataModule): - """The ``SummarizationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of - classmethods for loading data for text summarization.""" + """The ``SummarizationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of classmethods + for loading data for text summarization.""" input_transform_cls = InputTransform @@ -52,8 +52,8 @@ def from_csv( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SummarizationData": - """Load the :class:`~flash.text.seq2seq.summarization.data.SummarizationData` from CSV files containing - input text snippets and their corresponding target text snippets. + """Load the :class:`~flash.text.seq2seq.summarization.data.SummarizationData` from CSV files containing input + text snippets and their corresponding target text snippets. Input text snippets will be extracted from the ``input_field`` column in the CSV files. Target text snippets will be extracted from the ``target_field`` column in the CSV files. @@ -216,8 +216,8 @@ def from_json( field: Optional[str] = None, **data_module_kwargs: Any, ) -> "SummarizationData": - """Load the :class:`~flash.text.seq2seq.summarization.data.SummarizationData` from JSON files containing - input text snippets and their corresponding target text snippets. + """Load the :class:`~flash.text.seq2seq.summarization.data.SummarizationData` from JSON files containing input + text snippets and their corresponding target text snippets. Input text snippets will be extracted from the ``input_field`` column in the JSON files. Target text snippets will be extracted from the ``target_field`` column in the JSON files. @@ -418,8 +418,8 @@ def from_lists( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SummarizationData": - """Load the :class:`~flash.text.seq2seq.summarization.data.SummarizationData` from lists of input text - snippets and corresponding lists of target text snippets. + """Load the :class:`~flash.text.seq2seq.summarization.data.SummarizationData` from lists of input text snippets + and corresponding lists of target text snippets. To learn how to customize the transforms applied for each stage, read our :ref:`customizing transforms guide `. diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index 91a012c3b4..6aa2a4a76e 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -33,8 +33,8 @@ class TranslationData(DataModule): - """The ``TranslationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of - classmethods for loading data for text translation.""" + """The ``TranslationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of classmethods + for loading data for text translation.""" input_transform_cls = InputTransform @@ -52,8 +52,8 @@ def from_csv( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TranslationData": - """Load the :class:`~flash.text.seq2seq.translation.data.TranslationData` from CSV files containing input - text snippets and their corresponding target text snippets. + """Load the :class:`~flash.text.seq2seq.translation.data.TranslationData` from CSV files containing input text + snippets and their corresponding target text snippets. Input text snippets will be extracted from the ``input_field`` column in the CSV files. Target text snippets will be extracted from the ``target_field`` column in the CSV files. @@ -214,8 +214,8 @@ def from_json( field: Optional[str] = None, **data_module_kwargs: Any, ) -> "TranslationData": - """Load the :class:`~flash.text.seq2seq.translation.data.TranslationData` from JSON files containing input - text snippets and their corresponding target text snippets. + """Load the :class:`~flash.text.seq2seq.translation.data.TranslationData` from JSON files containing input text + snippets and their corresponding target text snippets. Input text snippets will be extracted from the ``input_field`` column in the JSON files. Target text snippets will be extracted from the ``target_field`` column in the JSON files. @@ -322,8 +322,8 @@ def from_hf_datasets( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TranslationData": - """Load the :class:`~flash.text.seq2seq.translation.data.TranslationData` from Hugging Face ``Dataset`` - objects containing input text snippets and their corresponding target text snippets. + """Load the :class:`~flash.text.seq2seq.translation.data.TranslationData` from Hugging Face ``Dataset`` objects + containing input text snippets and their corresponding target text snippets. Input text snippets will be extracted from the ``input_field`` column in the ``Dataset`` objects. Target text snippets will be extracted from the ``target_field`` column in the ``Dataset`` objects. @@ -415,8 +415,8 @@ def from_lists( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "TranslationData": - """Load the :class:`~flash.text.seq2seq.translation.data.TranslationData` from lists of input text snippets - and corresponding lists of target text snippets. + """Load the :class:`~flash.text.seq2seq.translation.data.TranslationData` from lists of input text snippets and + corresponding lists of target text snippets. To learn how to customize the transforms applied for each stage, read our :ref:`customizing transforms guide `. diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index 631020faec..a49217ad2b 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -238,8 +238,7 @@ def from_folders( transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "VideoClassificationData": - """Load the :class:`~flash.video.classification.data.VideoClassificationData` from folders containing - videos. + """Load the :class:`~flash.video.classification.data.VideoClassificationData` from folders containing videos. The supported file extensions are: ``.mp4``, and ``.avi``. For train, test, and validation data, the folders are expected to contain a sub-folder for each class. @@ -715,8 +714,8 @@ def from_csv( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "VideoClassificationData": - """Load the :class:`~flash.video.classification.data.VideoClassificationData` from CSV files containing - video file paths and their corresponding targets. + """Load the :class:`~flash.video.classification.data.VideoClassificationData` from CSV files containing video + file paths and their corresponding targets. Input video file paths will be extracted from the ``input_field`` column in the CSV files. The supported file extensions are: ``.mp4``, and ``.avi``. @@ -989,8 +988,8 @@ def from_fiftyone( transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "VideoClassificationData": - """Load the :class:`~flash.video.classification.data.VideoClassificationData` from FiftyOne - ``SampleCollection`` objects. + """Load the :class:`~flash.video.classification.data.VideoClassificationData` from FiftyOne ``SampleCollection`` + objects. The supported file extensions are: ``.mp4``, and ``.avi``. The targets will be extracted from the ``label_field`` in the ``SampleCollection`` objects and can be in any diff --git a/flash/video/classification/input_transform.py b/flash/video/classification/input_transform.py index 6e9fe45dee..a65d070a74 100644 --- a/flash/video/classification/input_transform.py +++ b/flash/video/classification/input_transform.py @@ -39,7 +39,6 @@ def normalize(x: Tensor) -> Tensor: @requires("video") @dataclass class VideoClassificationInputTransform(InputTransform): - image_size: int = 244 temporal_sub_sample: int = 8 mean: Tensor = torch.tensor([0.45, 0.45, 0.45]) diff --git a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py index f2e132a572..ba84184174 100644 --- a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py +++ b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py @@ -49,7 +49,6 @@ @dataclass class ImageClassificationInputTransform(InputTransform): - image_size: Tuple[int, int] = (196, 196) mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406) std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225) diff --git a/tests/core/data/test_base_viz.py b/tests/core/data/test_base_viz.py index 3f4d302f53..95b33c3eae 100644 --- a/tests/core/data/test_base_viz.py +++ b/tests/core/data/test_base_viz.py @@ -86,7 +86,6 @@ def check_reset(self): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") class TestBaseViz: def test_base_viz(self, tmpdir): - seed_everything(42) tmpdir = Path(tmpdir) @@ -117,7 +116,6 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: num_tests = 10 for stage in ("train", "val", "test", "predict"): - for _ in range(num_tests): for fcn_name in _CALLBACK_FUNCS: dm.data_fetcher.reset() diff --git a/tests/core/serve/test_compat/test_cached_property.py b/tests/core/serve/test_compat/test_cached_property.py index 6f2a8a7970..ffa50b360d 100644 --- a/tests/core/serve/test_compat/test_cached_property.py +++ b/tests/core/serve/test_compat/test_cached_property.py @@ -146,7 +146,6 @@ class MyClass(metaclass=MyMeta): def test_reuse_different_names(): """Disallow this case because decorated function a would not be cached.""" with pytest.raises(RuntimeError): - # noinspection PyUnusedLocal class ReusedCachedProperty: # NOSONAR """Test class.""" diff --git a/tests/core/serve/test_integration.py b/tests/core/serve/test_integration.py index 276083149d..198ff5106a 100644 --- a/tests/core/serve/test_integration.py +++ b/tests/core/serve/test_integration.py @@ -45,7 +45,6 @@ def test_start_server_with_repeated_exposed(session_global_datadir, lightning_sq composit = Composition(comp=comp, TESTING=True, DEBUG=True) app = composit.serve(host="0.0.0.0", port=8000) with TestClient(app) as tc: - meta = tc.get("http://127.0.0.1:8000/classify/meta") assert meta.status_code == 200 with (session_global_datadir / "fish.jpg").open("rb") as f: diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 9e5a3d48c8..d1ce3083a4 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -78,7 +78,6 @@ def __getitem__(self, index: int) -> Tensor: class DummyOutputTransform(OutputTransform): - pass diff --git a/tests/graph/classification/test_model.py b/tests/graph/classification/test_model.py index d0bed3280c..4e358388b7 100644 --- a/tests/graph/classification/test_model.py +++ b/tests/graph/classification/test_model.py @@ -32,7 +32,6 @@ class TestGraphClassifier(TaskTester): - task = GraphClassifier task_kwargs = {"num_features": 1, "num_classes": 2} cli_command = "graph_classification" diff --git a/tests/graph/embedding/test_model.py b/tests/graph/embedding/test_model.py index 49b104270c..e843c561c4 100644 --- a/tests/graph/embedding/test_model.py +++ b/tests/graph/embedding/test_model.py @@ -35,7 +35,6 @@ class TestGraphEmbedder(TaskTester): - task = GraphEmbedder task_args = (GCN(in_channels=1, hidden_channels=512, num_layers=4),) is_testing = _GRAPH_TESTING diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index 7757e51332..d7f86a6da5 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -95,7 +95,6 @@ def predict_dataloader(self): class BoringDataModule(LightningDataModule): - random_full: Dataset random_train: Subset random_val: Subset diff --git a/tests/helpers/task_tester.py b/tests/helpers/task_tester.py index 51e89807a3..3f34b9ecc3 100644 --- a/tests/helpers/task_tester.py +++ b/tests/helpers/task_tester.py @@ -145,15 +145,15 @@ def _test_load_from_checkpoint_dependency_error(self): def _test_init_dependency_error(self): - """Tests that a ``ModuleNotFoundError`` is raised when the task is instantiated if the required dependencies - are not available.""" + """Tests that a ``ModuleNotFoundError`` is raised when the task is instantiated if the required dependencies are not + available.""" with pytest.raises(ModuleNotFoundError, match="Required dependencies not available."): _ = self.instantiated_task class TaskTesterMeta(ABCMeta): - """The ``TaskTesterMeta`` is a metaclass which attaches a suite of tests to classes that extend ``TaskTester`` - based on the configuration variables they define. + """The ``TaskTesterMeta`` is a metaclass which attaches a suite of tests to classes that extend ``TaskTester`` based + on the configuration variables they define. These tests will also be wrapped with the appropriate marks to skip them if the required dependencies are not available. diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index 39ff6e9a04..e92bdac9fd 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -639,7 +639,6 @@ def test_from_bad_csv_no_image(bad_csv_no_image): def test_mixup(single_target_csv): @dataclass class MyTransform(ImageClassificationInputTransform): - alpha: float = 1.0 def mixup(self, batch): diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index 0fc4f02c48..13bfb67c2d 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -69,7 +69,6 @@ def __getitem__(self, idx): class TestObjectDetector(TaskTester): - task = ObjectDetector task_kwargs = {"num_classes": 2} cli_command = "object_detection" diff --git a/tests/image/instance_segmentation/test_model.py b/tests/image/instance_segmentation/test_model.py index 2518f1a07a..5efd09e4b7 100644 --- a/tests/image/instance_segmentation/test_model.py +++ b/tests/image/instance_segmentation/test_model.py @@ -91,7 +91,6 @@ def coco_instances(tmpdir): class TestInstanceSegmentation(TaskTester): - task = InstanceSegmentation task_kwargs = {"num_classes": 2} cli_command = "instance_segmentation" diff --git a/tests/pointcloud/detection/test_model.py b/tests/pointcloud/detection/test_model.py index eed0ad8448..556542c3aa 100644 --- a/tests/pointcloud/detection/test_model.py +++ b/tests/pointcloud/detection/test_model.py @@ -19,7 +19,6 @@ class TestPointCloudObjectDetector(TaskTester): - task = PointCloudObjectDetector task_args = (2,) cli_command = "pointcloud_detection" diff --git a/tests/pointcloud/segmentation/test_model.py b/tests/pointcloud/segmentation/test_model.py index 301d67913b..2ec1163acb 100644 --- a/tests/pointcloud/segmentation/test_model.py +++ b/tests/pointcloud/segmentation/test_model.py @@ -20,7 +20,6 @@ class TestPointCloudSegmentation(TaskTester): - task = PointCloudSegmentation task_args = (2,) cli_command = "pointcloud_segmentation" diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index 1ce1c5b54e..674772c9e4 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -28,7 +28,6 @@ class TestTabularClassifier(TaskTester): - task = TabularClassifier task_kwargs = { "parameters": {"categorical_fields": list(range(4))}, diff --git a/tests/tabular/forecasting/test_data.py b/tests/tabular/forecasting/test_data.py index 9814f5eec7..79683a65fa 100644 --- a/tests/tabular/forecasting/test_data.py +++ b/tests/tabular/forecasting/test_data.py @@ -22,8 +22,8 @@ @pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") @patch("flash.tabular.forecasting.input.TimeSeriesDataSet") def test_from_data_frame_time_series_data_set_single_call(patch_time_series_data_set): - """Tests that ``TabularForecastingData.from_data_frame`` calls ``TimeSeriesDataSet`` with the expected - parameters when called once with data for all stages.""" + """Tests that ``TabularForecastingData.from_data_frame`` calls ``TimeSeriesDataSet`` with the expected parameters + when called once with data for all stages.""" patch_time_series_data_set.return_value.get_parameters.return_value = {"test": None} train_data = MagicMock() @@ -51,8 +51,8 @@ def test_from_data_frame_time_series_data_set_single_call(patch_time_series_data @pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") @patch("flash.tabular.forecasting.input.TimeSeriesDataSet") def test_from_data_frame_time_series_data_set_multi_call(patch_time_series_data_set): - """Tests that ``TabularForecastingData.from_data_frame`` calls ``TimeSeriesDataSet`` with the expected - parameters when called separately for each stage.""" + """Tests that ``TabularForecastingData.from_data_frame`` calls ``TimeSeriesDataSet`` with the expected parameters + when called separately for each stage.""" patch_time_series_data_set.return_value.get_parameters.return_value = {"test": None} train_data = MagicMock() diff --git a/tests/tabular/regression/test_model.py b/tests/tabular/regression/test_model.py index 1518d5e6ea..6dbfe32a55 100644 --- a/tests/tabular/regression/test_model.py +++ b/tests/tabular/regression/test_model.py @@ -27,7 +27,6 @@ class TestTabularRegressor(TaskTester): - task = TabularRegressor task_kwargs = { "parameters": {"categorical_fields": list(range(4))}, diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py index f461511f18..05bf82b54e 100644 --- a/tests/text/classification/test_model.py +++ b/tests/text/classification/test_model.py @@ -30,7 +30,6 @@ class TestTextClassifier(TaskTester): - task = TextClassifier task_args = (2,) task_kwargs = {"backbone": TEST_BACKBONE} diff --git a/tests/text/question_answering/test_model.py b/tests/text/question_answering/test_model.py index eb9114aced..66446d736e 100644 --- a/tests/text/question_answering/test_model.py +++ b/tests/text/question_answering/test_model.py @@ -26,7 +26,6 @@ class TestQuestionAnsweringTask(TaskTester): - task = QuestionAnsweringTask task_kwargs = {"backbone": TEST_BACKBONE} cli_command = "question_answering" diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py index 3ccf800f89..dddca4b269 100644 --- a/tests/text/seq2seq/summarization/test_model.py +++ b/tests/text/seq2seq/summarization/test_model.py @@ -27,7 +27,6 @@ class TestSummarizationTask(TaskTester): - task = SummarizationTask task_kwargs = { "backbone": TEST_BACKBONE, diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py index bd5c01c7e7..b6ef8cbe42 100644 --- a/tests/text/seq2seq/translation/test_model.py +++ b/tests/text/seq2seq/translation/test_model.py @@ -27,7 +27,6 @@ class TestTranslationTask(TaskTester): - task = TranslationTask task_kwargs = { "backbone": TEST_BACKBONE, diff --git a/tests/video/classification/test_model.py b/tests/video/classification/test_model.py index d0f9e7615c..f182c01eaa 100644 --- a/tests/video/classification/test_model.py +++ b/tests/video/classification/test_model.py @@ -145,7 +145,6 @@ def mock_encoded_video_dataset_folder(tmpdir): @pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") def test_video_classifier_finetune_from_folder(tmpdir): with mock_encoded_video_dataset_folder(tmpdir) as (mock_folder, total_duration): - half_duration = total_duration / 2 - 1e-9 datamodule = VideoClassificationData.from_folders( @@ -169,7 +168,6 @@ def test_video_classifier_finetune_from_folder(tmpdir): @pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") def test_video_classifier_finetune_from_files(tmpdir): with mock_video_data_frame() as (mock_data_frame, total_duration): - half_duration = total_duration / 2 - 1e-9 datamodule = VideoClassificationData.from_files( @@ -194,7 +192,6 @@ def test_video_classifier_finetune_from_files(tmpdir): @pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") def test_video_classifier_finetune_from_data_frame(tmpdir): with mock_video_data_frame() as (mock_data_frame, total_duration): - half_duration = total_duration / 2 - 1e-9 datamodule = VideoClassificationData.from_data_frame( @@ -271,7 +268,6 @@ def test_video_classifier_predict_from_tensors(tmpdir): @pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") def test_video_classifier_finetune_from_csv(tmpdir): with mock_video_csv_file(tmpdir) as (mock_csv, total_duration): - half_duration = total_duration / 2 - 1e-9 datamodule = VideoClassificationData.from_csv( @@ -301,7 +297,6 @@ def test_video_classifier_finetune_fiftyone(tmpdir): dir_name, total_duration, ): - half_duration = total_duration / 2 - 1e-9 train_dataset = fo.Dataset.from_dir( From 2831da9c480bb114eeadef3a90568f3921b2f85e Mon Sep 17 00:00:00 2001 From: Arjun Sharda <77706434+ArjunSharda@users.noreply.github.com> Date: Fri, 3 Mar 2023 19:37:29 -0600 Subject: [PATCH 06/39] fix grammar/punctuation issues in README.md (#1517 Update README.md grammar / punctuation changes --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index bacc658a71..c485d01d1b 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ See [our installation guide](https://lightning-flash.readthedocs.io/en/latest/in ### Step 1. Load your data All data loading in Flash is performed via a `from_*` classmethod on a `DataModule`. -Which `DataModule` to use and which `from_*` methods are available depends on the task you want to perform. +To decide which `DataModule` to use and which `from_*` methods are available, it depends on the task you want to perform. For example, for image segmentation where your data is stored in folders, you would use the [`from_folders` method of the `SemanticSegmentationData` class](https://lightning-flash.readthedocs.io/en/latest/reference/semantic_segmentation.html#from-folders): ```py @@ -85,7 +85,7 @@ dm = SemanticSegmentationData.from_folders( Our tasks come loaded with pre-trained backbones and (where applicable) heads. You can view the available backbones to use with your task using [`available_backbones`](https://lightning-flash.readthedocs.io/en/latest/general/backbones.html). -Once you've chosen, create the model: +Once you've chosen one, create the model: ```py from flash.image import SemanticSegmentation @@ -119,7 +119,7 @@ trainer.save_checkpoint("semantic_segmentation_model.pt") ### Make predictions with Flash! -Serve in just 2 lines. +Serve in just 2 lines: ```py from flash.image import SemanticSegmentation @@ -174,7 +174,7 @@ In detail, the following methods are currently implemented: ### Flash Optimizers / Schedulers -With Flash, swapping among 40+ optimizers and 15 + schedulers recipes are simple. Find the list of available optimizers, schedulers as follows: +With Flash, swapping among 40+ optimizers and 15+ schedulers recipes are simple. Find the list of available optimizers, schedulers as follows: ```py ImageClassifier.available_optimizers() From 2ee1016c93cffbe104bd84cebc8910bb92b60fbf Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 6 Mar 2023 21:51:26 +0100 Subject: [PATCH 07/39] drop pep8 speaks --- .pep8speaks.yml | 31 ------------------------------- 1 file changed, 31 deletions(-) delete mode 100644 .pep8speaks.yml diff --git a/.pep8speaks.yml b/.pep8speaks.yml deleted file mode 100644 index 11a8164bac..0000000000 --- a/.pep8speaks.yml +++ /dev/null @@ -1,31 +0,0 @@ -# File : .pep8speaks.yml - -scanner: - diff_only: True # If False, the entire file touched by the Pull Request is scanned for errors. If True, only the diff is scanned. - linter: pycodestyle # Other option is flake8 - -pycodestyle: # Same as scanner.linter value. Other option is flake8 - max-line-length: 120 # Default is 79 in PEP 8 - ignore: # Errors and warnings to ignore - - W504 # line break after binary operator - - E402 # module level import not at top of file - - E731 # do not assign a lambda expression, use a def - - C406 # Unnecessary list literal - rewrite as a dict literal. - - E741 # ambiguous variable name - - F401 - - F841 - extend-ignore: E203, W503 - -no_blank_comment: True # If True, no comment is made on PR without any errors. -descending_issues_order: False # If True, PEP 8 issues in message will be displayed in descending order of line numbers in the file - -message: # Customize the comment made by the bot, - opened: # Messages when a new PR is submitted - header: "Hello @{name}! Thanks for opening this PR. " - # The keyword {name} is converted into the author's username - footer: "Do see the [Hitchhiker's guide to code style](https://goo.gl/hqbW4r)" - # The messages can be written as they would over GitHub - updated: # Messages when new commits are added to the PR - header: "Hello @{name}! Thanks for updating this PR. " - footer: "" # Why to comment the link to the style guide everytime? :) - no_errors: "There are currently no PEP 8 issues detected in this Pull Request. Cheers! :beers: " From 2625dcea244455c07e901157f7ffeb36c594ea49 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Mon, 6 Mar 2023 22:29:02 +0100 Subject: [PATCH 08/39] ci: prune & unify (#1526) * ci: prune & unify * drop --- .azure/gpu-special-tests.yml | 2 +- .azure/testing-template.yml | 6 +-- .github/workflows/ci-checks.yml | 20 ++++++++++ .github/workflows/ci-install-pkg.yml | 59 ---------------------------- .github/workflows/ci-schema.yml | 13 ------ .github/workflows/code-format.yml | 38 ------------------ .github/workflows/docs-check.yml | 46 ++++++++-------------- 7 files changed, 40 insertions(+), 144 deletions(-) create mode 100644 .github/workflows/ci-checks.yml delete mode 100644 .github/workflows/ci-install-pkg.yml delete mode 100644 .github/workflows/ci-schema.yml delete mode 100644 .github/workflows/code-format.yml diff --git a/.azure/gpu-special-tests.yml b/.azure/gpu-special-tests.yml index d16c5d6834..a3e33d7d63 100644 --- a/.azure/gpu-special-tests.yml +++ b/.azure/gpu-special-tests.yml @@ -20,7 +20,7 @@ jobs: # how much time to give 'run always even if cancelled tasks' before stopping them cancelTimeoutInMinutes: "2" - pool: lit-rtx-3090 + pool: "lit-rtx-3090" variables: DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' ) diff --git a/.azure/testing-template.yml b/.azure/testing-template.yml index 6bfd8b2201..4bd0e9006f 100644 --- a/.azure/testing-template.yml +++ b/.azure/testing-template.yml @@ -3,11 +3,11 @@ jobs: - job: displayName: "domain ${{dom}} with 2 GPU" # how long to run the job before automatically cancelling - timeoutInMinutes: 45 + timeoutInMinutes: "45" # how much time to give 'run always even if cancelled tasks' before stopping them - cancelTimeoutInMinutes: 2 + cancelTimeoutInMinutes: "2" - pool: lit-rtx-3090 + pool: "lit-rtx-3090" variables: DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' ) diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml new file mode 100644 index 0000000000..f581d407cf --- /dev/null +++ b/.github/workflows/ci-checks.yml @@ -0,0 +1,20 @@ +name: General Checks + +on: + push: + branches: [master, "release/*"] + pull_request: + branches: [master, "release/*"] + +jobs: + check-schema: + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.7.1 + with: + azure-dir: '' # ToDo + + check-package: + uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.7.1 + with: + actions-ref: v0.7.1 + artifact-name: dist-packages-${{ github.sha }} + import-name: "flash" diff --git a/.github/workflows/ci-install-pkg.yml b/.github/workflows/ci-install-pkg.yml deleted file mode 100644 index c5502104e4..0000000000 --- a/.github/workflows/ci-install-pkg.yml +++ /dev/null @@ -1,59 +0,0 @@ -name: Install package - -# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows -on: # Trigger the workflow on push or pull request, but only for the master branch - push: - branches: [master] - pull_request: - branches: [master] - -jobs: - pkg-check: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: 3.7 - - name: Create package - run: | - # python setup.py check --metadata --strict - python setup.py sdist - - name: Check package - run: | - pip install twine==3.2 - twine check dist/* - python setup.py clean - - pkg-install: - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - # max-parallel: 6 - matrix: - # PyTorch 1.5 is failing on Win and bolts requires torchvision>=0.5 - os: [ubuntu-20.04, macOS-12] # , windows-2022 - # fixme - python-version: [3.7] # , 3.8 - - steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - - name: Create package - run: | - # python setup.py check --metadata --strict - python setup.py sdist - - name: Install package - run: | - pip install virtualenv - virtualenv vEnv - source vEnv/bin/activate - pip install dist/* - cd .. & python -c "import pytorch_lightning as pl ; print(pl.__version__)" - cd .. & python -c "import flash ; print(flash.__version__)" - deactivate - rm -rf vEnv diff --git a/.github/workflows/ci-schema.yml b/.github/workflows/ci-schema.yml deleted file mode 100644 index 6ed51d72a4..0000000000 --- a/.github/workflows/ci-schema.yml +++ /dev/null @@ -1,13 +0,0 @@ -name: Check schema - -on: - push: - branches: [master, "release/*"] - pull_request: - branches: [master, "release/*"] - -jobs: - check-schema: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.7.1 - with: - azure-dir: '' diff --git a/.github/workflows/code-format.yml b/.github/workflows/code-format.yml deleted file mode 100644 index 8dd38eda2c..0000000000 --- a/.github/workflows/code-format.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: Check Code formatting - -# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows -on: # Trigger the workflow on push or pull request, but only for the master branch - push: {} - pull_request: - branches: [master] - -jobs: - pep8-check-flake8: - runs-on: ubuntu-20.04 - steps: - - name: Checkout - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: 3.7 - - name: Install dependencies - run: | - pip install flake8 - pip list - shell: bash - - name: PEP8 - run: flake8 . - - #typing-check-mypy: - # runs-on: ubuntu-20.04 - # steps: - # - uses: actions/checkout@master - # - uses: actions/setup-python@v4 - # with: - # python-version: 3.8 - # - name: Install mypy - # run: | - # pip install mypy - # pip list - # - name: mypy - # run: mypy diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml index 70858ef33a..3fc2f43767 100644 --- a/.github/workflows/docs-check.yml +++ b/.github/workflows/docs-check.yml @@ -24,37 +24,29 @@ jobs: uses: actions/cache@v3 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} - restore-keys: | - ${{ runner.os }}-pip- + key: pip-${{ hashFiles('requirements.txt') }} + restore-keys: pip- - name: Install dependencies run: | - sudo apt-get update - sudo apt-get install -y cmake pandoc - python --version + sudo apt-get update --fix-missing + # install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux + sudo apt-get install -y cmake pandoc texlive-latex-extra dvipng texlive-pictures pip --version pip install . --find-links https://download.pytorch.org/whl/cpu/torch_stable.html pip install --requirement requirements/docs.txt - # install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux - sudo apt-get update && sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures pip list shell: bash - name: Make Documentation - run: | - # First run the same pipeline as Read-The-Docs - cd docs - make clean - make html --debug --jobs 2 SPHINXOPTS="-W --keep-going" + working-directory: docs + run: make html --debug --jobs 2 SPHINXOPTS="-W --keep-going" - name: Upload built docs uses: actions/upload-artifact@v3 with: name: docs-results-${{ github.sha }} path: docs/build/html/ - # Use always() to always run this step to publish test results when there are test failures - if: success() test-docs: runs-on: ubuntu-20.04 @@ -73,20 +65,16 @@ jobs: uses: actions/cache@v3 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('requirements/base.txt') }} - restore-keys: | - ${{ runner.os }}-pip- + key: pip-${{ hashFiles('requirements/base.txt') }} + restore-keys: pip- - name: Install dependencies run: | - sudo apt-get update - sudo apt-get install -y cmake pandoc - sudo apt-get install -y libsndfile1 - pip install '.[all]' --find-links https://download.pytorch.org/whl/cpu/torch_stable.html - pip install '.[test]' --find-links https://download.pytorch.org/whl/cpu/torch_stable.html - pip install --requirement requirements/docs.txt - python --version + sudo apt-get update --fix-missing + sudo apt-get install -y cmake pandoc libsndfile1 pip --version + pip install '.[all,test]' --find-links https://download.pytorch.org/whl/cpu/torch_stable.html + pip install --requirement requirements/docs.txt pip list shell: bash @@ -99,11 +87,9 @@ jobs: key: flash-datasets-docs - name: Test Documentation + working-directory: docs env: SPHINX_MOCK_REQUIREMENTS: 0 FIFTYONE_DO_NOT_TRACK: true - run: | - # First run the same pipeline as Read-The-Docs - apt-get update && sudo apt-get install -y cmake - cd docs - FLASH_TESTING=1 make doctest + FLASH_TESTING: 1 + run: make doctest From 1d1b16d6a8a7c7a716cb6df3fe59d8912bbb0210 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Mon, 6 Mar 2023 22:47:04 +0100 Subject: [PATCH 09/39] pkg: merge `setup_tools` to setup.py (#1527) * pkg: merge setup_tools --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- flash/setup_tools.py | 74 ---------------------------- setup.py | 113 +++++++++++++++++++++++++++++++------------ 2 files changed, 83 insertions(+), 104 deletions(-) delete mode 100644 flash/setup_tools.py diff --git a/flash/setup_tools.py b/flash/setup_tools.py deleted file mode 100644 index 1dfcd9056f..0000000000 --- a/flash/setup_tools.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import re -from typing import List - -_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) - - -def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_chars: str = "#@") -> List[str]: - with open(os.path.join(path_dir, file_name)) as file: - lines = [ln.strip() for ln in file.readlines()] - reqs = [] - for ln in lines: - # filer all comments - found = [ln.index(ch) for ch in comment_chars if ch in ln] - if found: - ln = ln[: min(found)].strip() - # skip directly installed dependencies - if ln.startswith("http") or ln.startswith("git") or ln.startswith("-r"): - continue - if ln: # if requirement is not empty - reqs.append(ln) - return reqs - - -def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str: - """Load readme as decribtion. - - >>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - '
...' - """ - path_readme = os.path.join(path_dir, "README.md") - text = open(path_readme, encoding="utf-8").read() - - # drop images from readme - text = text.replace("![PT to PL](docs/source/_images/general/pl_quick_start_full_compressed.gif)", "") - - # https://github.com/Lightning-AI/lightning/raw/master/docs/source/_images/lightning_module/pt_to_pl.png - github_source_url = os.path.join(homepage, "raw", ver) - # replace relative repository path to absolute link to the release - # do not replace all "docs" as in the readme we reger some other sources with particular path to docs - text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}") - - # readthedocs badge - text = text.replace("badge/?version=stable", f"badge/?version={ver}") - text = text.replace("pytorch-lightning.readthedocs.io/en/stable/", f"pytorch-lightning.readthedocs.io/en/{ver}") - # codecov badge - text = text.replace("/branch/master/graph/badge.svg", f"/release/{ver}/graph/badge.svg") - # replace github badges for release ones - text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={ver}") - - skip_begin = r"" - skip_end = r"" - # todo: wrap content as commented description - text = re.sub(rf"{skip_begin}.+?{skip_end}", "", text, flags=re.IGNORECASE + re.DOTALL) - - # # https://github.com/Borda/pytorch-lightning/releases/download/1.1.0a6/codecov_badge.png - # github_release_url = os.path.join(homepage, "releases", "download", ver) - # # download badge and replace url with local file - # text = _parse_for_badge(text, github_release_url) - return text diff --git a/setup.py b/setup.py index d213dec382..982617eb07 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ # limitations under the License. import glob import os +import re from functools import partial from importlib.util import module_from_spec, spec_from_file_location from itertools import chain @@ -26,6 +27,61 @@ _PATH_REQUIRE = os.path.join(_PATH_ROOT, "requirements") +def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str: + """Load readme as decribtion. + + >>> _load_readme_description(_PATH_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + '
...' + """ + path_readme = os.path.join(path_dir, "README.md") + text = open(path_readme, encoding="utf-8").read() + + # drop images from readme + text = text.replace("![PT to PL](docs/source/_images/general/pl_quick_start_full_compressed.gif)", "") + + # https://github.com/Lightning-AI/lightning/raw/master/docs/source/_images/lightning_module/pt_to_pl.png + github_source_url = os.path.join(homepage, "raw", ver) + # replace relative repository path to absolute link to the release + # do not replace all "docs" as in the readme we reger some other sources with particular path to docs + text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}") + + # readthedocs badge + text = text.replace("badge/?version=stable", f"badge/?version={ver}") + text = text.replace("pytorch-lightning.readthedocs.io/en/stable/", f"pytorch-lightning.readthedocs.io/en/{ver}") + # codecov badge + text = text.replace("/branch/master/graph/badge.svg", f"/release/{ver}/graph/badge.svg") + # replace github badges for release ones + text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={ver}") + + skip_begin = r"" + skip_end = r"" + # todo: wrap content as commented description + text = re.sub(rf"{skip_begin}.+?{skip_end}", "", text, flags=re.IGNORECASE + re.DOTALL) + + # # https://github.com/Borda/pytorch-lightning/releases/download/1.1.0a6/codecov_badge.png + # github_release_url = os.path.join(homepage, "releases", "download", ver) + # # download badge and replace url with local file + # text = _parse_for_badge(text, github_release_url) + return text + + +def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_chars: str = "#@") -> list: + with open(os.path.join(path_dir, file_name)) as file: + lines = [ln.strip() for ln in file.readlines()] + reqs = [] + for ln in lines: + # filer all comments + found = [ln.index(ch) for ch in comment_chars if ch in ln] + if found: + ln = ln[: min(found)].strip() + # skip directly installed dependencies + if ln.startswith("http") or ln.startswith("git") or ln.startswith("-r"): + continue + if ln: # if requirement is not empty + reqs.append(ln) + return reqs + + def _load_py_module(fname, pkg="flash"): spec = spec_from_file_location( os.path.join(pkg, fname), @@ -37,43 +93,40 @@ def _load_py_module(fname, pkg="flash"): about = _load_py_module("__about__.py") -setup_tools = _load_py_module("setup_tools.py") -long_description = setup_tools._load_readme_description( - _PATH_ROOT, - homepage=about.__homepage__, - ver=about.__version__, -) +long_description = _load_readme_description(_PATH_ROOT, homepage=about.__homepage__, ver=about.__version__) def _expand_reqs(extras: dict, keys: list) -> list: return list(chain(*[extras[ex] for ex in keys])) -base_req = setup_tools._load_requirements(path_dir=_PATH_ROOT, file_name="requirements.txt") # find all extra requirements -_load_req = partial(setup_tools._load_requirements, path_dir=_PATH_REQUIRE) -found_req_files = sorted(os.path.basename(p) for p in glob.glob(os.path.join(_PATH_REQUIRE, "*.txt"))) -# remove datatype prefix -found_req_names = [os.path.splitext(req)[0].replace("datatype_", "") for req in found_req_files] -# define basic and extra extras -extras_req = { - name: _load_req(file_name=fname) for name, fname in zip(found_req_names, found_req_files) if "_" not in name -} -extras_req.update( - { - name: extras_req[name.split("_")[0]] + _load_req(file_name=fname) - for name, fname in zip(found_req_names, found_req_files) - if "_" in name +def _get_extras(path_dir: str = _PATH_REQUIRE): + _load_req = partial(_load_requirements, path_dir=path_dir) + found_req_files = sorted(os.path.basename(p) for p in glob.glob(os.path.join(path_dir, "*.txt"))) + # remove datatype prefix + found_req_names = [os.path.splitext(req)[0].replace("datatype_", "") for req in found_req_files] + # define basic and extra extras + extras_req = { + name: _load_req(file_name=fname) for name, fname in zip(found_req_names, found_req_files) if "_" not in name } -) -# some extra combinations -extras_req["vision"] = _expand_reqs(extras_req, ["image", "video"]) -extras_req["core"] = _expand_reqs(extras_req, ["image", "tabular", "text"]) -extras_req["all"] = _expand_reqs(extras_req, ["vision", "tabular", "text", "audio"]) -extras_req["dev"] = _expand_reqs(extras_req, ["all", "test", "docs"]) -# filter the uniques -extras_req = {n: list(set(req)) for n, req in extras_req.items()} + extras_req.update( + { + name: extras_req[name.split("_")[0]] + _load_req(file_name=fname) + for name, fname in zip(found_req_names, found_req_files) + if "_" in name + } + ) + # some extra combinations + extras_req["vision"] = _expand_reqs(extras_req, ["image", "video"]) + extras_req["core"] = _expand_reqs(extras_req, ["image", "tabular", "text"]) + extras_req["all"] = _expand_reqs(extras_req, ["vision", "tabular", "text", "audio"]) + extras_req["dev"] = _expand_reqs(extras_req, ["all", "test", "docs"]) + # filter the uniques + extras_req = {n: list(set(req)) for n, req in extras_req.items()} + return extras_req + # https://packaging.python.org/discussions/install-requires-vs-requirements / # keep the meta-data here for simplicity in reading this file... it's not obvious @@ -93,14 +146,14 @@ def _expand_reqs(extras: dict, keys: list) -> list: long_description=long_description, long_description_content_type="text/markdown", include_package_data=True, - extras_require=extras_req, + extras_require=_get_extras(), entry_points={ "console_scripts": ["flash=flash.__main__:main"], }, zip_safe=False, keywords=["deep learning", "pytorch", "AI"], python_requires=">=3.7", - install_requires=base_req, + install_requires=_load_requirements(path_dir=_PATH_ROOT, file_name="requirements.txt"), project_urls={ "Bug Tracker": "https://github.com/Lightning-AI/lightning-flash/issues", "Documentation": "https://lightning-flash.rtfd.io/en/latest/", From 14693975b6e74885c54b01c483429aaa5209537f Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 7 Mar 2023 00:51:37 +0100 Subject: [PATCH 10/39] update codeowners & docs --- .github/CODEOWNERS | 20 ++++++++++---------- docs/source/governance.rst | 14 +++++++------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index da8ff27410..c095e544a4 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -5,23 +5,23 @@ # the repo. Unless a later match takes precedence, # @global-owner1 and @global-owner2 will be requested for # review when someone opens a pull request. -* @ethanwharris @borda @tchaton @justusschock @krshrimali @kaushikb11 +* @ethanwharris @borda @krshrimali # owners -/.github/CODEOWNERS @williamfalcon +/.github/CODEOWNERS @williamfalcon # main -/README.md @ethanwharris @krshrimali +/README.md @ethanwharris @borda # installation -/setup.py @borda @ethanwharris @krshrimali -/__about__.py @borda @ethanwharris @krshrimali -/__init__.py @borda @ethanwharris @krshrimali +/setup.py @borda @ethanwharris +/__about__.py @borda @ethanwharris +/__init__.py @borda @ethanwharris # CI/CD -/.github/workflows/ @borda @ethanwharris @krshrimali +/.github/workflows/ @borda @ethanwharris # configs in root -/*.yml @borda @ethanwharris @krshrimali +/*.yml @borda @ethanwharris # Docs -/.github/ISSUE_TEMPLATE/*.md @borda @ethanwharris @krshrimali -/docs/source/conf.py @borda @ethanwharris +/.github/ISSUE_TEMPLATE/*.md @borda @ethanwharris +/docs/source/conf.py @borda @ethanwharris /flash/core/integrations/labelstudio @KonstantinKorotaev @niklub diff --git a/docs/source/governance.rst b/docs/source/governance.rst index 22594c8207..f6bc95a203 100644 --- a/docs/source/governance.rst +++ b/docs/source/governance.rst @@ -6,22 +6,22 @@ Flash Governance | Persons of interest Leads ----- - Ethan Harris (`ethanwharris `_) -- Kushashwa Ravi Shrimali (`krshrimali `_) -- Thomas Chaton (`tchaton `_) Core Maintainers ---------------- -- William Falcon (`williamFalcon `_) - Jirka Borovec (`Borda `_) -- Kaushik Bokka (`kaushikb11 `_) - Justus Schock (`justusschock `_) -- Akihiro Nitta (`akihironitta `_) - Aniket Maurya (`aniketmaurya `_) -- Sivaraman Karthik Rangasai (`karthikrangasai `_) - Pietro Lesci (`pietrolesci `_) Alumni ------ -- Sean Narenthiran (`SeanNaren `_) +- Akihiro Nitta (`akihironitta `_) - Ananya Harsh Jha (`ananyahjha93 `_) +- Kaushik Bokka (`kaushikb11 `_) +- Kushashwa Ravi Shrimali (`krshrimali `_) +- Sean Narenthiran (`SeanNaren `_) +- Sivaraman Karthik Rangasai (`karthikrangasai `_) +- Thomas Chaton (`tchaton `_) +- William Falcon (`williamFalcon `_) From 12cb417e6530fd4a8649d64e558a567bc31c156f Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 7 Mar 2023 12:37:51 +0100 Subject: [PATCH 11/39] pkg/ci: freeze requirements (#1528) * setup frezee * req. freeze * docs + note * guardian * ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .azure/gpu-special-tests.yml | 5 +- .azure/testing-template.yml | 2 + .github/workflows/ci-testing.yml | 62 ++++++++------- .github/workflows/docs-check.yml | 3 + .github/workflows/pypi-release.yml | 5 +- Makefile | 2 +- docs/source/governance.rst | 4 +- flash/image/detection/data.py | 16 ++-- flash/image/instance_segmentation/data.py | 11 ++- requirements.txt | 14 ++-- requirements/datatype_audio.txt | 12 +-- requirements/datatype_graph.txt | 14 ++-- requirements/datatype_image.txt | 22 +++--- requirements/datatype_image_extras.txt | 8 +- requirements/datatype_image_extras_baal.txt | 4 +- requirements/datatype_pointcloud.txt | 10 ++- requirements/datatype_tabular.txt | 8 +- requirements/datatype_text.txt | 20 ++--- requirements/datatype_video.txt | 8 +- requirements/datatype_video_extras.txt | 4 +- requirements/docs.txt | 28 ++++--- requirements/notebooks.txt | 3 - requirements/serve.txt | 24 +++--- requirements/test.txt | 15 +--- setup.py | 86 +++++++++++++-------- 25 files changed, 221 insertions(+), 169 deletions(-) delete mode 100644 requirements/notebooks.txt diff --git a/.azure/gpu-special-tests.yml b/.azure/gpu-special-tests.yml index a3e33d7d63..f0cc5a3392 100644 --- a/.azure/gpu-special-tests.yml +++ b/.azure/gpu-special-tests.yml @@ -53,9 +53,10 @@ jobs: - bash: | # python -m pip install "pip==20.1" - pip install '.[image]' learn2learn - pip install '.[test]' --upgrade-strategy only-if-needed + pip install '.[image,test]' learn2learn pip list + env: + FREEZE_REQUIREMENTS: 1 displayName: 'Install dependencies' - bash: | diff --git a/.azure/testing-template.yml b/.azure/testing-template.yml index 4bd0e9006f..5e6fb15a5a 100644 --- a/.azure/testing-template.yml +++ b/.azure/testing-template.yml @@ -50,6 +50,8 @@ jobs: fi pip install '.[test]' --upgrade-strategy only-if-needed pip list + env: + FREEZE_REQUIREMENTS: 1 displayName: 'Install dependencies' - bash: | diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 08cf84744c..d6884eefff 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -3,9 +3,9 @@ name: CI testing # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows on: # Trigger the workflow on push or pull request, but only for the master branch push: - branches: [master] + branches: ["master", "release/*"] pull_request: - branches: [master] + branches: ["master", "release/*"] concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} @@ -16,7 +16,7 @@ defaults: shell: bash jobs: - pytest: + pytester: runs-on: ${{ matrix.os }} strategy: @@ -34,21 +34,23 @@ jobs: - { os: ubuntu-20.04, python-version: 3.9, requires: 'oldest' } - { os: ubuntu-20.04, python-version: 3.9, requires: 'latest' } include: - - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'pre', topic: [ 'core' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image','image_extras' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'image','image_extras_baal' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'video' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'video','video_extras' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'tabular' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'text' ] } - - { os: 'ubuntu-20.04', python-version: 3.8, requires: 'latest', release: 'stable', topic: [ 'pointcloud' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'serve' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'graph' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, requires: 'latest', release: 'stable', topic: [ 'audio' ] } + - { os: 'ubuntu-20.04', python-version: 3.9, release: 'pre', topic: [ 'core' ] } + - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'image' ] } + - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'image','image_extras' ] } + - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'image','image_extras_baal' ] } + - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'video' ] } + - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'video','video_extras' ] } + - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'tabular' ] } + - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'text' ] } + - { os: 'ubuntu-20.04', python-version: 3.8, release: 'stable', topic: [ 'pointcloud' ] } + - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'serve' ] } + - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'graph' ] } + - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'audio' ] } # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 35 + env: + FREEZE_REQUIREMENTS: 1 steps: - uses: actions/checkout@v3 @@ -84,7 +86,7 @@ jobs: shell: python - run: echo "period=$(python -c 'import time; days = time.time() / 60 / 60 / 24; print(int(days / 7))' 2>&1)" >> $GITHUB_OUTPUT - if: matrix.requires == 'latest' + if: matrix.requires != 'latest' id: times - name: Get pip cache dir @@ -96,8 +98,7 @@ jobs: with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-${{ matrix.python-version }}-td${{ steps.times.outputs.period }}-${{ join(matrix.topic,'-') }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements.txt') }} - restore-keys: | - ${{ runner.os }}-${{ matrix.python-version }}-td${{ steps.times.outputs.period }}-${{ join(matrix.topic,'-') }}-${{ matrix.requires }}-pip- + restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-td${{ steps.times.outputs.period }}-${{ join(matrix.topic,'-') }}-${{ matrix.requires }}-pip- - name: Install graph test dependencies if: contains( matrix.topic , 'graph' ) @@ -147,15 +148,7 @@ jobs: FLASH_TEST_TOPIC: ${{ join(matrix.topic,',') }} FIFTYONE_DO_NOT_TRACK: true run: | - coverage run --source flash -m pytest flash tests --reruns 3 --reruns-delay 2 -v \ - --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - - - name: Upload pytest test results - uses: actions/upload-artifact@v3 - with: - name: pytest-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }} - path: junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - if: failure() + coverage run --source flash -m pytest flash tests --reruns 3 --reruns-delay 2 -v - name: Statistics if: success() @@ -172,3 +165,18 @@ jobs: env_vars: OS,PYTHON name: codecov-umbrella fail_ci_if_error: false + + + testing-guardian: + runs-on: ubuntu-latest + needs: pytester + if: always() + steps: + - run: echo "${{ needs.pytester.result }}" + - name: failing... + if: needs.pytester.result == 'failure' + run: exit 1 + - name: cancelled or skipped... + if: contains(fromJSON('["cancelled", "skipped"]'), needs.pytester.result) + timeout-minutes: 1 + run: sleep 90 diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml index 3fc2f43767..46a7c910b5 100644 --- a/.github/workflows/docs-check.yml +++ b/.github/workflows/docs-check.yml @@ -6,6 +6,9 @@ on: # Trigger the workflow on push or pull request, but only for the master bran pull_request: branches: [master] +env: + FREEZE_REQUIREMENTS: 1 + jobs: make-docs: runs-on: ubuntu-20.04 diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml index 922554d4a8..3e71d514e1 100644 --- a/.github/workflows/pypi-release.yml +++ b/.github/workflows/pypi-release.yml @@ -20,11 +20,10 @@ jobs: python-version: 3.7 - name: Install dependencies - run: >- - pip install --user --upgrade setuptools wheel + run: pip install --user --upgrade setuptools wheel build - name: Build run: | - python setup.py sdist bdist_wheel + python -m build ls -lh dist/ # We do this, since failures on test.pypi aren't that bad diff --git a/Makefile b/Makefile index 39ffa062e5..689d1a241f 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ test: clean docs: clean git submodule update --init --recursive - pip install --quiet -r requirements/docs.txt + pip install . --quiet -r requirements/docs.txt python -m sphinx -b html -W --keep-going docs/source docs/build clean: diff --git a/docs/source/governance.rst b/docs/source/governance.rst index f6bc95a203..60f25ccb84 100644 --- a/docs/source/governance.rst +++ b/docs/source/governance.rst @@ -11,16 +11,16 @@ Core Maintainers ---------------- - Jirka Borovec (`Borda `_) - Justus Schock (`justusschock `_) -- Aniket Maurya (`aniketmaurya `_) -- Pietro Lesci (`pietrolesci `_) Alumni ------ - Akihiro Nitta (`akihironitta `_) +- Aniket Maurya (`aniketmaurya `_) - Ananya Harsh Jha (`ananyahjha93 `_) - Kaushik Bokka (`kaushikb11 `_) - Kushashwa Ravi Shrimali (`krshrimali `_) +- Pietro Lesci (`pietrolesci `_) - Sean Narenthiran (`SeanNaren `_) - Sivaraman Karthik Rangasai (`karthikrangasai `_) - Thomas Chaton (`tchaton `_) diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index b618db3f7b..98ac401d73 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -621,10 +621,8 @@ def from_coco( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ObjectDetectionData": - """.. _COCO: https://www.immersivelimit.com/tutorials/create-coco-annotations-from-scratch. - - Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders - and annotation files in the `COCO JSON format `_. + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders and + annotation files in the `COCO JSON format `_. For help understanding and using the COCO format, take a look at this tutorial: `Create COCO annotations from scratch `__. @@ -735,6 +733,8 @@ def from_coco( >>> shutil.rmtree("train_folder") >>> shutil.rmtree("predict_folder") >>> os.remove("train_annotations.json") + + .. _COCO: https://www.immersivelimit.com/tutorials/create-coco-annotations-from-scratch """ return cls.from_icedata( train_folder=train_folder, @@ -767,10 +767,8 @@ def from_voc( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "ObjectDetectionData": - """.. _PASCAL: http://host.robots.ox.ac.uk/pascal/VOC/ - - Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders - and annotation files in the `PASCAL VOC (Visual Object Challenge) XML format `_. + """Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders and + annotation files in the `PASCAL VOC (Visual Object Challenge) XML format `_. To learn how to customize the transforms applied for each stage, read our :ref:`customizing transforms guide `. @@ -891,6 +889,8 @@ def from_voc( >>> shutil.rmtree("train_folder") >>> shutil.rmtree("predict_folder") >>> shutil.rmtree("train_annotations") + + .. _PASCAL: http://host.robots.ox.ac.uk/pascal/VOC/ """ return cls.from_icedata( train_folder=train_folder, diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index e546137799..476901a629 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -134,7 +134,7 @@ def from_coco( data folders and annotation files in the `COCO JSON format `_. For help understanding and using the COCO format, take a look at this tutorial: `Create COCO annotations from - scratch `__. + scratch `_. To learn how to customize the transforms applied for each stage, read our :ref:`customizing transforms guide `. @@ -248,6 +248,8 @@ def from_coco( >>> shutil.rmtree("train_folder") >>> shutil.rmtree("predict_folder") >>> os.remove("train_annotations.json") + + .. _COCO: https://www.immersivelimit.com/tutorials/create-coco-annotations-from-scratch """ return cls.from_icedata( train_folder=train_folder, @@ -284,9 +286,8 @@ def from_voc( **data_module_kwargs: Any, ): """Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the given - data folders, mask folders, and annotation files in the `PASCAL VOC (Visual Object Challenge) XML format. - - `_. + data folders, mask folders, and annotation files in the `PASCAL VOC `_ (Visual Object Challenge) XML + format. .. note:: All three arguments `*_folder`, `*_target_folder`, and `*_ann_folder` are needed to load data for a particular stage. @@ -428,6 +429,8 @@ def from_voc( >>> shutil.rmtree("train_masks") >>> shutil.rmtree("predict_folder") >>> shutil.rmtree("train_annotations") + + .. _PASCAL: http://host.robots.ox.ac.uk/pascal/VOC/ """ return cls.from_icedata( train_folder=train_folder, diff --git a/requirements.txt b/requirements.txt index ff50947e4d..3c42fa44e1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,13 @@ packaging setuptools<=59.5.0 # Prevent install bug with tensorboard -numpy<1.24 # freeze for using np.long +numpy<1.24 # strict - freeze for using np.long torch>=1.7.1 -torchmetrics>=0.5.0,!=0.5.1, <0.11.0 -pytorch-lightning>=1.3.6 +torchmetrics>0.5.1, <0.11.0 +pytorch-lightning>=1.3.6, <1.9.0 pyDeprecate -pandas>=1.1.0 +pandas>=1.1.0, <=1.5.2 jsonargparse[signatures]>=3.17.0, <=4.9.0 -click>=7.1.2 -protobuf<=3.20.1 +click>=7.1.2, <=8.1.3 +protobuf <=3.20.1 fsspec[http]>=2021.6.1,<=2022.7.1 -lightning-utilities>=0.3.0 +lightning-utilities >=0.4.1 diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt index 4db6c7765a..5844520660 100644 --- a/requirements/datatype_audio.txt +++ b/requirements/datatype_audio.txt @@ -1,5 +1,7 @@ -torchaudio -torchvision -librosa>=0.8.1 -transformers>=4.13.0 -datasets>=1.16.1 +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup + +torchaudio <=0.13.1 +torchvision <=0.14.1 +librosa>=0.8.1, <=0.9.2 +transformers>=4.13.0, <=4.25.1 +datasets>=1.16.1, <=2.8.0 diff --git a/requirements/datatype_graph.txt b/requirements/datatype_graph.txt index e764f0b109..ff65a1399c 100644 --- a/requirements/datatype_graph.txt +++ b/requirements/datatype_graph.txt @@ -1,6 +1,8 @@ -torch-scatter -torch-sparse -torch-geometric>=2.0.0 -torch-cluster -networkx -class-resolver>=0.3.2 +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup + +torch-scatter <=2.1.0 +torch-sparse <=0.6.16 +torch-geometric>=2.0.0, <=2.2.0 +torch-cluster <=1.6.0 +networkx <=2.8.8 +class-resolver>=0.3.2, <=0.3.10 diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt index 1d3aef6b53..1237792cd6 100644 --- a/requirements/datatype_image.txt +++ b/requirements/datatype_image.txt @@ -1,10 +1,12 @@ -torchvision -timm>=0.4.5 -lightning-bolts>=0.3.3 -Pillow>=7.2 -albumentations>=1.0 -pystiche==1.* -segmentation-models-pytorch>=0.2.0 -ftfy -regex -sahi<0.11 # Fixes compatibility with icevision +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup + +torchvision <=0.14.1 +timm>=0.4.5, <=0.4.12 +lightning-bolts>=0.3.3, <=0.6.0 +Pillow>=7.2, <=9.3.0 +albumentations>=1.0, <=1.3.0 +pystiche>=1.0.0, <=1.0.1 +segmentation-models-pytorch>=0.2.0, <=0.3.1 +ftfy <=6.1.1 +regex <=2022.10.31 +sahi <0.11 # strict - Fixes compatibility with icevision diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt index 5a5d6556f2..8b58b7aafa 100644 --- a/requirements/datatype_image_extras.txt +++ b/requirements/datatype_image_extras.txt @@ -1,4 +1,6 @@ -matplotlib +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup + +matplotlib <=3.6.2 fiftyone classy_vision vissl>=0.1.5 @@ -12,7 +14,7 @@ fastface fairscale # pin PL for testing, remove when fastface is updated -pytorch-lightning<1.5.0 +pytorch-lightning <1.5.0 torchmetrics<0.8.0 # pinned PL so we force a compatible TM version # effdet had an issue with PL 1.12, and icevision doesn't support effdet's latest version yet (0.3.0) -torch<1.12 +torch <1.12 diff --git a/requirements/datatype_image_extras_baal.txt b/requirements/datatype_image_extras_baal.txt index 37386a6359..8cfaa31311 100644 --- a/requirements/datatype_image_extras_baal.txt +++ b/requirements/datatype_image_extras_baal.txt @@ -1,2 +1,4 @@ +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup + # This is a separate file, as baal integration is affected by vissl installation (conflicts) -baal>=1.3.2 +baal>=1.3.2, <=1.7.0 diff --git a/requirements/datatype_pointcloud.txt b/requirements/datatype_pointcloud.txt index cc6437f44c..aa95459400 100644 --- a/requirements/datatype_pointcloud.txt +++ b/requirements/datatype_pointcloud.txt @@ -1,4 +1,6 @@ -open3d==0.13 -torch==1.7.1 -torchvision -tensorboard +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup + +open3d ==0.13 +torch ==1.7.1 +torchvision ==0.8.2 +tensorboard <=2.11.0 diff --git a/requirements/datatype_tabular.txt b/requirements/datatype_tabular.txt index b124b00994..e81f4b143a 100644 --- a/requirements/datatype_tabular.txt +++ b/requirements/datatype_tabular.txt @@ -1,5 +1,7 @@ -scikit-learn -pytorch-forecasting>=0.9.0 +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup + +scikit-learn <=1.2.0 +pytorch-forecasting>=0.9.0, <=0.10.3 pytorch-tabular==0.7.0 torchmetrics<0.8.0 # pytorch-tabular pins PL so we force a compatible TM version -omegaconf<=2.1.1 +omegaconf<=2.1.1, <=2.1.1 diff --git a/requirements/datatype_text.txt b/requirements/datatype_text.txt index c61d6fb591..22da582b8c 100644 --- a/requirements/datatype_text.txt +++ b/requirements/datatype_text.txt @@ -1,9 +1,11 @@ -torchvision -sentencepiece>=0.1.95 -filelock -transformers>=4.5 -torchmetrics[text]>=0.5.1 -datasets>=1.8 -sentence-transformers -ftfy -regex +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup + +torchvision <=0.14.1 +sentencepiece>=0.1.95, <=0.1.97 +filelock <=3.8.2 +transformers>=4.5, <=4.25.1 +torchmetrics[text]>=0.5.1, <0.11.0 +datasets>=1.8, <=2.8.0 +sentence-transformers <=2.2.2 +ftfy <=6.1.1 +regex <=2022.10.31 diff --git a/requirements/datatype_video.txt b/requirements/datatype_video.txt index f05034036d..9c125d4d34 100644 --- a/requirements/datatype_video.txt +++ b/requirements/datatype_video.txt @@ -1,4 +1,6 @@ -torchvision -Pillow>=7.2 -kornia>=0.5.1 +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup + +torchvision <=0.14.1 +Pillow>=7.2, <=9.3.0 +kornia>=0.5.1, <=0.6.9 pytorchvideo==0.1.2 diff --git a/requirements/datatype_video_extras.txt b/requirements/datatype_video_extras.txt index 00de5ca1d2..1c10853c29 100644 --- a/requirements/datatype_video_extras.txt +++ b/requirements/datatype_video_extras.txt @@ -1 +1,3 @@ -fiftyone +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup + +fiftyone <=0.18.0 diff --git a/requirements/docs.txt b/requirements/docs.txt index 7c8aaca419..7cec7ee7a7 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,16 +1,20 @@ -sphinx>=4.0,<5.0 -myst-parser>=0.15 -nbsphinx>=0.8.5 -ipython[notebook] -pandoc>=1.0 -docutils>=0.16 -sphinxcontrib-fulltoc>=1.0 +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup + +sphinx >=4.0, <5.0 +myst-parser >=0.15 +nbsphinx >=0.8.5, <=0.8.10 +nbformat <5.7.0 +ipython[notebook] <8.7.0 +pandoc >=1.0 +docutils >=0.16, <=0.19 +sphinxcontrib-fulltoc >=1.0, <=1.2.0 sphinxcontrib-mockautodoc +sphinx-autodoc-typehints >=1.0, <=1.22 +sphinx-paramlinks >=0.5.1, <0.5.4 +sphinx-togglebutton >=0.2 +sphinx-copybutton >=0.3 +jinja2 >=3.0.0, <3.1.0 + pt-lightning-sphinx-theme @ https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip -sphinx-autodoc-typehints>=1.0 -sphinx-paramlinks>=0.5.1 -sphinx-togglebutton>=0.2 -sphinx-copybutton>=0.3 -jinja2>=3.0.0,<3.1.0 -r ../_notebooks/.actions/requirements.txt diff --git a/requirements/notebooks.txt b/requirements/notebooks.txt deleted file mode 100644 index 5ea3ab796f..0000000000 --- a/requirements/notebooks.txt +++ /dev/null @@ -1,3 +0,0 @@ -nbconvert -jupyter_client -jupyter diff --git a/requirements/serve.txt b/requirements/serve.txt index 2014e85e9d..792fb2e20e 100644 --- a/requirements/serve.txt +++ b/requirements/serve.txt @@ -1,12 +1,14 @@ -pillow -pyyaml -cytoolz -graphviz -tqdm -fastapi>=0.65.2 -pydantic>1.8.1 +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup + +pillow <=9.3.0 +pyyaml <=6.0 +cytoolz <=0.12.1 +graphviz <=0.20.1 +tqdm <=4.64.1 +fastapi>=0.65.2, <=0.68.2 +pydantic>1.8.1, <=1.10.2 starlette==0.14.2 -uvicorn[standard]>=0.12.0 -aiofiles -jinja2>=3.0.0,<3.1.0 -torchvision +uvicorn[standard]>=0.12.0, <=0.20.0 +aiofiles <=22.1.0 +jinja2>=3.0.0, <3.1.0 +torchvision <=0.14.1 diff --git a/requirements/test.txt b/requirements/test.txt index 8b5899f7d3..b1344ce6a7 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,20 +1,13 @@ +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup + coverage codecov>=2.1 -pytest>=5.0,<7.0 -pytest-flake8 -flake8 +pytest>=5.0, <7.0 pytest-doctestplus>=0.9.0 pytest-rerunfailures>=10.0 pytest-forked -# install pkg -check-manifest -twine==3.2 - -# formatting -pre-commit -isort -#mypy scikit-learn pytest_mock +qpth <0.0.14 torch_optimizer diff --git a/setup.py b/setup.py index 982617eb07..bfca41797b 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ # http://blog.ionelmc.ro/2014/05/25/python-packaging/ _PATH_ROOT = os.path.dirname(__file__) _PATH_REQUIRE = os.path.join(_PATH_ROOT, "requirements") +_FREEZE_REQUIREMENTS = bool(int(os.environ.get("FREEZE_REQUIREMENTS", 0))) def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str: @@ -36,9 +37,6 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str: path_readme = os.path.join(path_dir, "README.md") text = open(path_readme, encoding="utf-8").read() - # drop images from readme - text = text.replace("![PT to PL](docs/source/_images/general/pl_quick_start_full_compressed.gif)", "") - # https://github.com/Lightning-AI/lightning/raw/master/docs/source/_images/lightning_module/pt_to_pl.png github_source_url = os.path.join(homepage, "raw", ver) # replace relative repository path to absolute link to the release @@ -53,40 +51,64 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str: # replace github badges for release ones text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={ver}") - skip_begin = r"" - skip_end = r"" - # todo: wrap content as commented description - text = re.sub(rf"{skip_begin}.+?{skip_end}", "", text, flags=re.IGNORECASE + re.DOTALL) - - # # https://github.com/Borda/pytorch-lightning/releases/download/1.1.0a6/codecov_badge.png - # github_release_url = os.path.join(homepage, "releases", "download", ver) - # # download badge and replace url with local file - # text = _parse_for_badge(text, github_release_url) return text -def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_chars: str = "#@") -> list: +def _augment_requirement(ln: str, comment_char: str = "#", unfreeze: bool = True) -> str: + """Adjust the upper version contrains. + + >>> _augment_requirement("arrow<=1.2.2,>=1.2.0 # anything", unfreeze=False) + 'arrow<=1.2.2,>=1.2.0' + >>> _augment_requirement("arrow<=1.2.2,>=1.2.0 # strict", unfreeze=False) + 'arrow<=1.2.2,>=1.2.0 # strict' + >>> _augment_requirement("arrow<=1.2.2,>=1.2.0 # my name", unfreeze=True) + 'arrow>=1.2.0' + >>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # strict", unfreeze=True) + 'arrow>=1.2.0, <=1.2.2 # strict' + >>> _augment_requirement("arrow", unfreeze=True) + 'arrow' + """ + # filer all comments + if comment_char in ln: + comment = ln[ln.index(comment_char) :] + ln = ln[: ln.index(comment_char)] + is_strict = "strict" in comment + else: + is_strict = False + req = ln.strip() + # skip directly installed dependencies + if not req or any(c in req for c in ["http:", "https:", "@"]): + return "" + + # remove version restrictions unless they are strict + if unfreeze and "<" in req and not is_strict: + req = re.sub(r",? *<=? *[\d\.\*]+,? *", "", req).strip() + + # adding strict back to the comment + if is_strict: + req += " # strict" + + return req + + +def _load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: bool = not _FREEZE_REQUIREMENTS) -> list: + """Loading requirements from a file. + + >>> path_req = os.path.join(_PATH_ROOT, "requirements") + >>> _load_requirements(path_req, "docs.txt") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ['sphinx>=4.0', ...] + """ with open(os.path.join(path_dir, file_name)) as file: lines = [ln.strip() for ln in file.readlines()] - reqs = [] - for ln in lines: - # filer all comments - found = [ln.index(ch) for ch in comment_chars if ch in ln] - if found: - ln = ln[: min(found)].strip() - # skip directly installed dependencies - if ln.startswith("http") or ln.startswith("git") or ln.startswith("-r"): - continue - if ln: # if requirement is not empty - reqs.append(ln) + reqs = [_augment_requirement(ln, unfreeze=unfreeze) for ln in lines] + reqs = [str(req) for req in reqs if req and not req.startswith("-r")] + # filter empty lines and containing @ which means redirect to some git/http + reqs = [req for req in reqs if not any(c in req for c in ["@", "http://", "https://"])] return reqs def _load_py_module(fname, pkg="flash"): - spec = spec_from_file_location( - os.path.join(pkg, fname), - os.path.join(_PATH_ROOT, pkg, fname), - ) + spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_ROOT, pkg, fname)) py = module_from_spec(spec) spec.loader.exec_module(py) return py @@ -94,15 +116,13 @@ def _load_py_module(fname, pkg="flash"): about = _load_py_module("__about__.py") -long_description = _load_readme_description(_PATH_ROOT, homepage=about.__homepage__, ver=about.__version__) - def _expand_reqs(extras: dict, keys: list) -> list: return list(chain(*[extras[ex] for ex in keys])) # find all extra requirements -def _get_extras(path_dir: str = _PATH_REQUIRE): +def _get_extras(path_dir: str = _PATH_REQUIRE) -> dict: _load_req = partial(_load_requirements, path_dir=path_dir) found_req_files = sorted(os.path.basename(p) for p in glob.glob(os.path.join(path_dir, "*.txt"))) # remove datatype prefix @@ -143,10 +163,9 @@ def _get_extras(path_dir: str = _PATH_REQUIRE): download_url="https://github.com/Lightning-AI/lightning-flash", license=about.__license__, packages=find_packages(exclude=["tests", "tests.*"]), - long_description=long_description, + long_description=_load_readme_description(_PATH_ROOT, homepage=about.__homepage__, ver=about.__version__), long_description_content_type="text/markdown", include_package_data=True, - extras_require=_get_extras(), entry_points={ "console_scripts": ["flash=flash.__main__:main"], }, @@ -154,6 +173,7 @@ def _get_extras(path_dir: str = _PATH_REQUIRE): keywords=["deep learning", "pytorch", "AI"], python_requires=">=3.7", install_requires=_load_requirements(path_dir=_PATH_ROOT, file_name="requirements.txt"), + extras_require=_get_extras(), project_urls={ "Bug Tracker": "https://github.com/Lightning-AI/lightning-flash/issues", "Documentation": "https://lightning-flash.rtfd.io/en/latest/", From 67cea56fd2eb1cace7afbcc97fbf3c1d6c32e043 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 23 Mar 2023 19:29:42 +0100 Subject: [PATCH 12/39] Bump Lightning-AI/utilities from 0.7.1 to 0.8.0 (#1533) * Bump Lightning-AI/utilities from 0.7.1 to 0.8.0 Bumps [Lightning-AI/utilities](https://github.com/Lightning-AI/utilities) from 0.7.1 to 0.8.0. - [Release notes](https://github.com/Lightning-AI/utilities/releases) - [Changelog](https://github.com/Lightning-AI/utilities/blob/main/CHANGELOG.md) - [Commits](https://github.com/Lightning-AI/utilities/compare/v0.7.1...v0.8.0) --- updated-dependencies: - dependency-name: Lightning-AI/utilities dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] * Apply suggestions from code review --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- .github/workflows/ci-checks.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index f581d407cf..66add3768c 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -8,13 +8,13 @@ on: jobs: check-schema: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.7.1 + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.8.0 with: azure-dir: '' # ToDo check-package: - uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.7.1 + uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.8.0 with: - actions-ref: v0.7.1 + actions-ref: v0.8.0 artifact-name: dist-packages-${{ github.sha }} import-name: "flash" From babc109a022275b3a654379f312e6f7399a9f4b8 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 23 Mar 2023 20:37:17 +0100 Subject: [PATCH 13/39] ci: drop circleci --- .circleci/config.yml | 124 -------------------------------- .github/workflows/docs-link.yml | 13 ---- MANIFEST.in | 1 - docs/source/governance.rst | 2 +- setup.cfg | 1 - 5 files changed, 1 insertion(+), 140 deletions(-) delete mode 100755 .circleci/config.yml delete mode 100644 .github/workflows/docs-link.yml diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100755 index 9897e168da..0000000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,124 +0,0 @@ -# Python CircleCI 2.1 configuration file. -version: 2.1 - -orbs: - gcp-gke: circleci/gcp-gke@1.4.0 - go: circleci/go@1.7.1 - codecov: codecov/codecov@1.1.0 - -trigger: - tags: - include: - - '*' - branches: - include: - - "master" - - "release/*" - - "refs/tags/*" - -pr: - - "master" - - "release/*" - -references: - - checkout_ml_testing: &checkout_ml_testing - run: - name: Checkout ml-testing-accelerators - command: | - git clone https://github.com/GoogleCloudPlatform/ml-testing-accelerators.git - cd ml-testing-accelerators - git fetch origin 5e88ac24f631c27045e62f0e8d5dfcf34e425e25:stable - git checkout stable - cd .. - - install_jsonnet: &install_jsonnet - run: - name: Install jsonnet - command: | - go install github.com/google/go-jsonnet/cmd/jsonnet@latest - - update_jsonnet: &update_jsonnet - run: - name: Update jsonnet - command: | - export PR_NUMBER=$(git ls-remote origin "pull/*/head" | grep -F -f <(git rev-parse HEAD) | awk -F'/' '{print $3}') - export SHA=$(git rev-parse --short HEAD) - python -c "fname = 'dockers/tpu-tests/tpu_test_cases.jsonnet' ; data = open(fname).read().replace('{PYTORCH_VERSION}', '$XLA_VER') - data = data.replace('{PYTHON_VERSION}', '$PYTHON_VER').replace('{PR_NUMBER}', '$PR_NUMBER').replace('{SHA}', '$SHA') ; open(fname, 'w').write(data)" - cat dockers/tpu-tests/tpu_test_cases.jsonnet - - deploy_cluster: &deploy_cluster - run: - name: Deploy the job on the kubernetes cluster - command: | - export PATH=$PATH:$HOME/go/bin - job_name=$(jsonnet -J ml-testing-accelerators/ dockers/tpu-tests/tpu_test_cases.jsonnet | kubectl create -f -) && \ - job_name=${job_name#job.batch/} - job_name=${job_name% created} - pod_name=$(kubectl get po -l controller-uid=`kubectl get job $job_name -o "jsonpath={.metadata.labels.controller-uid}"` | awk 'match($0,!/NAME/) {print $1}') - echo "GKE pod name: $pod_name" - echo "Waiting on kubernetes job: $job_name" - i=0 && \ - # N checks spaced 30s apart = 900s total. - status_code=2 && \ - # Check on the job periodically. Set the status code depending on what - # happened to the job in Kubernetes. If we try MAX_CHECKS times and - # still the job hasn't finished, give up and return the starting - # non-zero status code. - printf "Waiting for job to finish: " && \ - while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "."; fi; sleep $CHECK_SPEEP; done && \ - echo "Done waiting. Job status code: $status_code" && \ - kubectl logs -f $pod_name --container=train > /tmp/full_output.txt - if grep -q '' /tmp/full_output.txt ; then csplit /tmp/full_output.txt '//'; else mv /tmp/full_output.txt xx00; fi && \ - # First portion is the test logs. Print these to Github Action stdout. - cat xx00 && \ - echo "Done with log retrieval attempt." && \ - exit $status_code - - stats: &stats - run: - name: Statistics - command: | - mv ./xx01 coverage.xml - -jobs: - - TPU-tests: - executor: - name: go/default - tag: '1.17' - docker: - - image: circleci/python:3.7 - environment: - - XLA_VER: 1.9 - - PYTHON_VER: 3.7 - - MAX_CHECKS: 1000 - - CHECK_SPEEP: 5 - steps: - - checkout - - go/install: - version: "1.17" - - *checkout_ml_testing - - gcp-gke/install - - gcp-gke/update-kubeconfig-with-credentials: - cluster: $GKE_CLUSTER - perform-login: true - - *install_jsonnet - - *update_jsonnet - - *deploy_cluster - - *stats - - codecov/upload: - file: coverage.xml - flags: tpu,pytest - upload_name: TPU-coverage - - - store_artifacts: - path: coverage.xml - - -workflows: - version: 2 - ci-runs: - jobs: - - TPU-tests diff --git a/.github/workflows/docs-link.yml b/.github/workflows/docs-link.yml deleted file mode 100644 index c75a2636a9..0000000000 --- a/.github/workflows/docs-link.yml +++ /dev/null @@ -1,13 +0,0 @@ -name: "Add Docs Link" - -on: [status] - -jobs: - circleci_artifacts_redirector_job: - runs-on: ubuntu-latest - steps: - - uses: larsoner/circleci-artifacts-redirector-action@master - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - artifact-path: 0/html/index.html - circleci-jobs: build-Docs diff --git a/MANIFEST.in b/MANIFEST.in index 1f249755df..8aad84f3de 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -36,7 +36,6 @@ exclude *.yml prune .git prune .github -prune .circleci prune notebook* prune temp* prune test* diff --git a/docs/source/governance.rst b/docs/source/governance.rst index 60f25ccb84..902beea85d 100644 --- a/docs/source/governance.rst +++ b/docs/source/governance.rst @@ -10,7 +10,6 @@ Leads Core Maintainers ---------------- - Jirka Borovec (`Borda `_) -- Justus Schock (`justusschock `_) Alumni ------ @@ -18,6 +17,7 @@ Alumni - Akihiro Nitta (`akihironitta `_) - Aniket Maurya (`aniketmaurya `_) - Ananya Harsh Jha (`ananyahjha93 `_) +- Justus Schock (`justusschock `_) - Kaushik Bokka (`kaushikb11 `_) - Kushashwa Ravi Shrimali (`krshrimali `_) - Pietro Lesci (`pietrolesci `_) diff --git a/setup.cfg b/setup.cfg index 726c254010..d55a737e65 100644 --- a/setup.cfg +++ b/setup.cfg @@ -69,7 +69,6 @@ ignore = *.yml .github .github/* - .circleci [mypy] From 433474a303ea61e44f5ecece2fdb9506b2538cd4 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 24 Mar 2023 09:27:55 +0100 Subject: [PATCH 14/39] move package to `src/` (#1536) * src * examples * fix * docs * ci * mock * qpth * manifest * dag * xfail --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/labeler.yml | 20 +++++++------- .github/workflows/ci-testing.yml | 17 +++--------- .gitignore | 6 ++-- .readthedocs.yml | 2 ++ MANIFEST.in | 16 +++-------- README.md | 2 +- docs/source/conf.py | 2 +- docs/source/general/production.rst | 4 +-- docs/source/general/serve.rst | 2 +- docs/source/integrations/baal.rst | 2 +- docs/source/integrations/fiftyone.rst | 6 ++-- docs/source/integrations/learn2learn.rst | 2 +- .../integrations/pytorch_forecasting.rst | 2 +- .../source/reference/audio_classification.rst | 2 +- .../source/reference/graph_classification.rst | 2 +- docs/source/reference/graph_embedder.rst | 2 +- .../source/reference/image_classification.rst | 6 ++-- .../image_classification_multi_label.rst | 2 +- docs/source/reference/image_embedder.rst | 2 +- .../reference/instance_segmentation.rst | 2 +- docs/source/reference/keypoint_detection.rst | 2 +- docs/source/reference/object_detection.rst | 6 ++-- .../reference/pointcloud_object_detection.rst | 2 +- .../reference/pointcloud_segmentation.rst | 2 +- docs/source/reference/question_answering.rst | 2 +- .../reference/semantic_segmentation.rst | 6 ++-- docs/source/reference/speech_recognition.rst | 6 ++-- docs/source/reference/style_transfer.rst | 2 +- docs/source/reference/summarization.rst | 6 ++-- .../reference/tabular_classification.rst | 6 ++-- docs/source/reference/tabular_forecasting.rst | 2 +- docs/source/reference/template.rst | 2 +- docs/source/reference/text_classification.rst | 6 ++-- .../text_classification_multi_label.rst | 2 +- docs/source/reference/text_embedder.rst | 2 +- docs/source/reference/translation.rst | 6 ++-- .../source/reference/video_classification.rst | 2 +- docs/source/template/backbones.rst | 6 ++-- docs/source/template/data.rst | 26 +++++++++--------- docs/source/template/examples.rst | 6 ++-- docs/source/template/optional.rst | 6 ++-- docs/source/template/task.rst | 8 +++--- .../audio_classification.py | 0 .../face_detection.py | 0 .../graph_classification.py | 0 .../graph_embedder.py | 0 .../image_classification.py | 0 .../image_classification_multi_label.py | 0 .../image_embedder.py | 0 .../instance_segmentation.py | 0 .../image_classification_active_learning.py | 0 .../fiftyone/image_classification.py | 0 .../image_classification_fiftyone_datasets.py | 0 .../integrations/fiftyone/image_embedding.py | 0 .../integrations/fiftyone/object_detection.py | 0 .../labelstudio/image_classification.py | 0 .../labelstudio/text_classification.py | 0 .../labelstudio/video_classification.py | 0 .../image_classification_imagenette_mini.py | 0 .../tabular_forecasting_interpretable.py | 0 .../keypoint_detection.py | 0 .../object_detection.py | 0 .../pointcloud_detection.py | 0 .../pointcloud_segmentation.py | 0 .../question_answering.py | 0 .../semantic_segmentation.py | 0 .../serve/generic/boston_prediction/client.py | 0 .../boston_prediction/inference_server.py | 0 .../boston_prediction/requirements.txt | 0 .../serve/generic/detection/classes.txt | 0 .../serve/generic/detection/client.py | 0 .../serve/generic/detection/inference.py | 0 .../serve/generic/detection/input.jpg | Bin .../serve/image_classification}/__init__.py | 0 .../serve/image_classification/client.py | 0 .../image_classification/inference_server.py | 0 .../serve/object_detection/client.py | 0 .../object_detection/inference_server.py | 0 .../serve/semantic_segmentation/client.py | 0 .../semantic_segmentation/inference_server.py | 0 .../serve/speech_recognition/client.py | 0 .../speech_recognition/inference_server.py | 0 .../serve/summarization/client.py | 0 .../serve/summarization/inference_server.py | 0 .../serve/tabular_classification/client.py | 0 .../inference_server.py | 0 .../serve/text_classification/client.py | 0 .../text_classification/inference_server.py | 0 .../serve/translation/client.py | 0 .../serve/translation/inference_server.py | 0 .../speech_recognition.py | 0 .../style_transfer.py | 0 {flash_examples => examples}/summarization.py | 0 .../tabular_classification.py | 0 .../tabular_forecasting.py | 0 .../tabular_regression.py | 0 {flash_examples => examples}/template.py | 0 .../text_classification.py | 0 .../text_classification_multi_label.py | 0 {flash_examples => examples}/text_embedder.py | 0 {flash_examples => examples}/translation.py | 0 .../video_classification.py | 0 .../visualizations/pointcloud_detection.py | 0 .../visualizations/pointcloud_segmentation.py | 0 requirements.txt | 6 ++-- requirements/test.txt | 3 +- setup.cfg | 5 ++-- setup.py | 5 ++-- {flash => src/flash}/__about__.py | 0 {flash => src/flash}/__init__.py | 0 {flash => src/flash}/__main__.py | 0 {flash => src/flash}/assets/example.wav | Bin {flash => src/flash}/assets/fish.jpg | Bin {flash => src/flash}/assets/road.png | Bin {flash => src/flash}/assets/starry_night.jpg | Bin {flash => src/flash}/audio/__init__.py | 0 .../flash}/audio/classification/__init__.py | 0 .../flash}/audio/classification/cli.py | 0 .../flash}/audio/classification/data.py | 0 .../flash}/audio/classification/input.py | 0 .../audio/classification/input_transform.py | 0 .../audio/speech_recognition/__init__.py | 0 .../audio/speech_recognition/backbone.py | 0 .../flash}/audio/speech_recognition/cli.py | 0 .../audio/speech_recognition/collate.py | 0 .../flash}/audio/speech_recognition/data.py | 0 .../flash}/audio/speech_recognition/input.py | 0 .../flash}/audio/speech_recognition/model.py | 0 .../speech_recognition/output_transform.py | 0 .../core/data => src/flash/core}/__init__.py | 0 {flash => src/flash}/core/adapter.py | 0 {flash => src/flash}/core/classification.py | 0 .../io => src/flash/core/data}/__init__.py | 0 {flash => src/flash}/core/data/base_viz.py | 0 {flash => src/flash}/core/data/batch.py | 0 {flash => src/flash}/core/data/callback.py | 0 {flash => src/flash}/core/data/data_module.py | 0 .../flash/core/data/io}/__init__.py | 0 .../core/data/io/classification_input.py | 0 {flash => src/flash}/core/data/io/input.py | 0 .../flash}/core/data/io/input_transform.py | 0 {flash => src/flash}/core/data/io/output.py | 0 .../flash}/core/data/io/output_transform.py | 0 .../core/data/io/transform_predictions.py | 0 {flash => src/flash}/core/data/output.py | 0 {flash => src/flash}/core/data/properties.py | 0 {flash => src/flash}/core/data/splits.py | 0 {flash => src/flash}/core/data/transforms.py | 0 .../flash/core/data/utilities}/__init__.py | 0 .../core/data/utilities/classification.py | 0 .../flash}/core/data/utilities/collate.py | 0 .../flash}/core/data/utilities/data_frame.py | 0 .../flash}/core/data/utilities/loading.py | 0 .../flash}/core/data/utilities/paths.py | 0 .../flash}/core/data/utilities/samples.py | 0 .../flash}/core/data/utilities/sort.py | 0 {flash => src/flash}/core/data/utils.py | 0 {flash => src/flash}/core/finetuning.py | 0 {flash => src/flash}/core/heads.py | 0 {flash => src/flash}/core/hooks.py | 0 .../flash/core/integrations}/__init__.py | 0 .../core/integrations/fiftyone/__init__.py | 0 .../core/integrations/fiftyone/utils.py | 0 .../core/integrations/icevision}/__init__.py | 0 .../core/integrations/icevision/adapter.py | 0 .../core/integrations/icevision/backbones.py | 0 .../core/integrations/icevision/data.py | 0 .../core/integrations/icevision/transforms.py | 0 .../core/integrations/icevision/wrappers.py | 0 .../integrations/labelstudio}/__init__.py | 0 .../core/integrations/labelstudio/input.py | 0 .../integrations/labelstudio/visualizer.py | 0 .../pytorch_forecasting/__init__.py | 0 .../pytorch_forecasting/adapter.py | 0 .../pytorch_forecasting/backbones.py | 0 .../pytorch_forecasting/transforms.py | 0 .../integrations/pytorch_tabular}/__init__.py | 0 .../integrations/pytorch_tabular/adapter.py | 0 .../integrations/pytorch_tabular/backbones.py | 0 .../integrations/transformers}/__init__.py | 0 .../core/integrations/transformers/collate.py | 0 {flash => src/flash}/core/model.py | 0 .../flash}/core/optimizers/__init__.py | 0 {flash => src/flash}/core/optimizers/lamb.py | 0 {flash => src/flash}/core/optimizers/lars.py | 0 .../flash}/core/optimizers/lr_scheduler.py | 0 .../flash}/core/optimizers/optimizers.py | 0 .../flash}/core/optimizers/schedulers.py | 0 {flash => src/flash}/core/registry.py | 0 {flash => src/flash}/core/regression.py | 0 {flash => src/flash}/core/serve/__init__.py | 0 .../flash}/core/serve/_compat/__init__.py | 0 .../core/serve/_compat/cached_property.py | 0 {flash => src/flash}/core/serve/component.py | 0 .../flash}/core/serve/composition.py | 0 {flash => src/flash}/core/serve/core.py | 0 {flash => src/flash}/core/serve/dag/NOTICE | 0 .../flash/core/serve/dag}/__init__.py | 0 .../flash}/core/serve/dag/optimization.py | 0 {flash => src/flash}/core/serve/dag/order.py | 0 .../flash}/core/serve/dag/rewrite.py | 0 {flash => src/flash}/core/serve/dag/task.py | 0 {flash => src/flash}/core/serve/dag/utils.py | 0 .../flash}/core/serve/dag/utils_test.py | 0 .../flash}/core/serve/dag/visualize.py | 0 {flash => src/flash}/core/serve/decorators.py | 0 {flash => src/flash}/core/serve/execution.py | 0 .../flash}/core/serve/flash_components.py | 0 .../flash/core/serve/interfaces}/__init__.py | 0 .../flash}/core/serve/interfaces/http.py | 0 .../flash}/core/serve/interfaces/models.py | 0 .../serve/interfaces/templates}/__init__.py | 0 .../core/serve/interfaces/templates/dag.html | 0 {flash => src/flash}/core/serve/server.py | 0 .../flash}/core/serve/types/__init__.py | 0 {flash => src/flash}/core/serve/types/base.py | 0 {flash => src/flash}/core/serve/types/bbox.py | 0 .../flash}/core/serve/types/image.py | 0 .../flash}/core/serve/types/label.py | 0 .../flash}/core/serve/types/number.py | 0 .../flash}/core/serve/types/repeated.py | 0 .../flash}/core/serve/types/table.py | 0 {flash => src/flash}/core/serve/types/text.py | 0 {flash => src/flash}/core/serve/utils.py | 0 {flash => src/flash}/core/trainer.py | 0 .../flash/core/utilities}/__init__.py | 0 .../flash}/core/utilities/apply_func.py | 0 .../flash}/core/utilities/compatibility.py | 0 .../flash}/core/utilities/embedder.py | 0 .../flash}/core/utilities/flash_cli.py | 0 .../flash}/core/utilities/imports.py | 0 .../flash}/core/utilities/isinstance.py | 0 .../flash}/core/utilities/lightning_cli.py | 0 .../flash}/core/utilities/providers.py | 0 .../flash}/core/utilities/stability.py | 0 {flash => src/flash}/core/utilities/stages.py | 0 {flash => src/flash}/core/utilities/types.py | 0 .../flash}/core/utilities/url_error.py | 0 {flash => src/flash}/graph/__init__.py | 0 {flash => src/flash}/graph/backbones.py | 0 .../flash}/graph/classification/__init__.py | 0 .../flash}/graph/classification/cli.py | 0 .../flash}/graph/classification/data.py | 0 .../flash}/graph/classification/input.py | 0 .../graph/classification/input_transform.py | 0 .../flash}/graph/classification/model.py | 0 {flash => src/flash}/graph/collate.py | 0 .../flash}/graph/embedding/__init__.py | 0 {flash => src/flash}/graph/embedding/model.py | 0 {flash => src/flash}/image/__init__.py | 0 .../flash}/image/classification/__init__.py | 0 .../flash}/image/classification/adapters.py | 0 .../classification/backbones/__init__.py | 0 .../image/classification/backbones/clip.py | 0 .../image/classification/backbones/resnet.py | 0 .../image/classification/backbones/timm.py | 0 .../classification/backbones/torchvision.py | 0 .../classification/backbones/transformers.py | 0 .../flash}/image/classification/cli.py | 0 .../flash}/image/classification/data.py | 0 .../flash}/image/classification/heads.py | 0 .../flash}/image/classification/input.py | 0 .../image/classification/input_transform.py | 0 .../classification/integrations}/__init__.py | 0 .../integrations/baal/__init__.py | 0 .../classification/integrations/baal/data.py | 0 .../integrations/baal/dropout.py | 0 .../classification/integrations/baal/loop.py | 0 .../integrations/learn2learn.py | 0 .../flash}/image/classification/model.py | 0 {flash => src/flash}/image/data.py | 0 .../flash}/image/detection/__init__.py | 0 .../flash}/image/detection/backbones.py | 0 {flash => src/flash}/image/detection/cli.py | 0 {flash => src/flash}/image/detection/data.py | 0 {flash => src/flash}/image/detection/input.py | 0 {flash => src/flash}/image/detection/model.py | 0 .../flash}/image/detection/output.py | 0 .../flash}/image/embedding/__init__.py | 0 .../flash}/image/embedding/heads/__init__.py | 0 .../image/embedding/heads/vissl_heads.py | 0 .../flash}/image/embedding/losses/__init__.py | 0 .../image/embedding/losses/vissl_losses.py | 0 {flash => src/flash}/image/embedding/model.py | 0 .../image/embedding/strategies/__init__.py | 0 .../image/embedding/strategies/default.py | 0 .../embedding/strategies/vissl_strategies.py | 0 .../image/embedding/transforms/__init__.py | 0 .../embedding/transforms/vissl_transforms.py | 0 .../flash/image/embedding/vissl}/__init__.py | 0 .../flash}/image/embedding/vissl/adapter.py | 0 .../flash}/image/embedding/vissl/hooks.py | 0 .../embedding/vissl/transforms/__init__.py | 0 .../embedding/vissl/transforms/multicrop.py | 0 .../embedding/vissl/transforms/utilities.py | 0 .../flash}/image/face_detection/__init__.py | 0 .../face_detection/backbones/__init__.py | 0 .../backbones/fastface_backbones.py | 0 .../flash}/image/face_detection/cli.py | 0 .../flash}/image/face_detection/data.py | 0 .../flash}/image/face_detection/input.py | 0 .../image/face_detection/input_transform.py | 0 .../flash}/image/face_detection/model.py | 0 .../image/face_detection/output_transform.py | 0 .../image/instance_segmentation/__init__.py | 0 .../image/instance_segmentation/backbones.py | 0 .../flash}/image/instance_segmentation/cli.py | 0 .../image/instance_segmentation/data.py | 0 .../image/instance_segmentation/model.py | 0 .../image/keypoint_detection/__init__.py | 0 .../image/keypoint_detection/backbones.py | 0 .../flash}/image/keypoint_detection/cli.py | 0 .../flash}/image/keypoint_detection/data.py | 0 .../keypoint_detection/input_transform.py | 0 .../flash}/image/keypoint_detection/model.py | 0 .../flash}/image/segmentation/__init__.py | 0 .../flash}/image/segmentation/backbones.py | 0 .../flash}/image/segmentation/cli.py | 0 .../flash}/image/segmentation/data.py | 0 .../flash}/image/segmentation/heads.py | 0 .../flash}/image/segmentation/input.py | 0 .../image/segmentation/input_transform.py | 0 .../flash}/image/segmentation/model.py | 0 .../flash}/image/segmentation/output.py | 0 .../flash}/image/segmentation/viz.py | 0 .../flash}/image/style_transfer/__init__.py | 0 .../flash}/image/style_transfer/backbones.py | 0 .../flash}/image/style_transfer/cli.py | 0 .../flash}/image/style_transfer/data.py | 0 .../image/style_transfer/input_transform.py | 0 .../flash}/image/style_transfer/model.py | 0 .../flash}/image/style_transfer/utils.py | 0 {flash => src/flash}/pointcloud/__init__.py | 0 .../flash}/pointcloud/detection/__init__.py | 0 .../flash}/pointcloud/detection/backbones.py | 0 .../flash}/pointcloud/detection/cli.py | 0 .../flash}/pointcloud/detection/data.py | 0 .../flash}/pointcloud/detection/datasets.py | 0 .../flash}/pointcloud/detection/input.py | 0 .../flash}/pointcloud/detection/model.py | 0 .../detection}/open3d_ml/__init__.py | 0 .../pointcloud/detection/open3d_ml/app.py | 0 .../detection/open3d_ml/backbones.py | 0 .../pointcloud/detection/open3d_ml/input.py | 0 .../pointcloud/segmentation/__init__.py | 0 .../pointcloud/segmentation/backbones.py | 0 .../flash}/pointcloud/segmentation/cli.py | 0 .../flash}/pointcloud/segmentation/data.py | 0 .../pointcloud/segmentation/datasets.py | 0 .../flash}/pointcloud/segmentation/input.py | 0 .../flash}/pointcloud/segmentation/model.py | 0 .../segmentation/open3d_ml}/__init__.py | 0 .../pointcloud/segmentation/open3d_ml/app.py | 0 .../segmentation/open3d_ml/backbones.py | 0 .../open3d_ml/sequences_dataset.py | 0 {flash => src/flash}/tabular/__init__.py | 0 .../flash}/tabular/classification/__init__.py | 0 .../flash}/tabular/classification/cli.py | 0 .../flash}/tabular/classification/data.py | 0 .../flash}/tabular/classification/input.py | 0 .../flash}/tabular/classification/model.py | 0 .../flash}/tabular/classification/utils.py | 0 {flash => src/flash}/tabular/data.py | 0 .../flash}/tabular/forecasting/__init__.py | 0 .../flash}/tabular/forecasting/cli.py | 0 .../flash}/tabular/forecasting/data.py | 0 .../flash}/tabular/forecasting/input.py | 0 .../flash}/tabular/forecasting/model.py | 0 {flash => src/flash}/tabular/input.py | 0 .../flash}/tabular/regression/__init__.py | 0 .../flash}/tabular/regression/cli.py | 0 .../flash}/tabular/regression/data.py | 0 .../flash}/tabular/regression/input.py | 0 .../flash}/tabular/regression/model.py | 0 {flash => src/flash}/template/__init__.py | 0 .../template/classification/__init__.py | 0 .../template/classification/backbones.py | 0 .../flash}/template/classification/data.py | 0 .../flash}/template/classification/model.py | 0 {flash => src/flash}/text/__init__.py | 0 .../flash}/text/classification/__init__.py | 0 .../flash}/text/classification/adapters.py | 0 .../text/classification/backbones/__init__.py | 0 .../text/classification/backbones/clip.py | 0 .../classification/backbones/huggingface.py | 0 .../flash}/text/classification/cli.py | 0 .../flash}/text/classification/collate.py | 0 .../flash}/text/classification/data.py | 0 .../flash}/text/classification/input.py | 0 .../flash}/text/classification/model.py | 0 .../flash}/text/embedding/__init__.py | 0 .../flash}/text/embedding/backbones.py | 0 {flash => src/flash}/text/embedding/model.py | 0 {flash => src/flash}/text/input.py | 0 {flash => src/flash}/text/ort_callback.py | 0 .../text/question_answering/__init__.py | 0 .../flash}/text/question_answering/cli.py | 0 .../flash}/text/question_answering/collate.py | 0 .../flash}/text/question_answering/data.py | 0 .../flash}/text/question_answering/input.py | 0 .../flash}/text/question_answering/model.py | 0 .../question_answering/output_transform.py | 0 {flash => src/flash}/text/seq2seq/__init__.py | 0 .../flash}/text/seq2seq/core/__init__.py | 0 .../flash}/text/seq2seq/core/collate.py | 0 .../flash}/text/seq2seq/core/input.py | 0 .../flash}/text/seq2seq/core/model.py | 0 .../text/seq2seq/summarization/__init__.py | 0 .../flash}/text/seq2seq/summarization/cli.py | 0 .../flash}/text/seq2seq/summarization/data.py | 0 .../text/seq2seq/summarization/model.py | 0 .../text/seq2seq/translation/__init__.py | 0 .../flash}/text/seq2seq/translation/cli.py | 0 .../flash}/text/seq2seq/translation/data.py | 0 .../flash}/text/seq2seq/translation/model.py | 0 {flash => src/flash}/video/__init__.py | 0 .../flash/video/classification}/__init__.py | 0 .../flash}/video/classification/cli.py | 0 .../flash}/video/classification/data.py | 0 .../flash}/video/classification/input.py | 0 .../video/classification/input_transform.py | 0 .../flash}/video/classification/model.py | 0 .../flash}/video/classification/utils.py | 0 tests/core/serve/test_integration.py | 10 ++++++- tests/examples/test_integrations.py | 2 +- tests/examples/test_scripts.py | 2 +- tests/image/style_transfer/test_model.py | 3 ++ 427 files changed, 124 insertions(+), 127 deletions(-) rename {flash_examples => examples}/audio_classification.py (100%) rename {flash_examples => examples}/face_detection.py (100%) rename {flash_examples => examples}/graph_classification.py (100%) rename {flash_examples => examples}/graph_embedder.py (100%) rename {flash_examples => examples}/image_classification.py (100%) rename {flash_examples => examples}/image_classification_multi_label.py (100%) rename {flash_examples => examples}/image_embedder.py (100%) rename {flash_examples => examples}/instance_segmentation.py (100%) rename {flash_examples => examples}/integrations/baal/image_classification_active_learning.py (100%) rename {flash_examples => examples}/integrations/fiftyone/image_classification.py (100%) rename {flash_examples => examples}/integrations/fiftyone/image_classification_fiftyone_datasets.py (100%) rename {flash_examples => examples}/integrations/fiftyone/image_embedding.py (100%) rename {flash_examples => examples}/integrations/fiftyone/object_detection.py (100%) rename {flash_examples => examples}/integrations/labelstudio/image_classification.py (100%) rename {flash_examples => examples}/integrations/labelstudio/text_classification.py (100%) rename {flash_examples => examples}/integrations/labelstudio/video_classification.py (100%) rename {flash_examples => examples}/integrations/learn2learn/image_classification_imagenette_mini.py (100%) rename {flash_examples => examples}/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py (100%) rename {flash_examples => examples}/keypoint_detection.py (100%) rename {flash_examples => examples}/object_detection.py (100%) rename {flash_examples => examples}/pointcloud_detection.py (100%) rename {flash_examples => examples}/pointcloud_segmentation.py (100%) rename {flash_examples => examples}/question_answering.py (100%) rename {flash_examples => examples}/semantic_segmentation.py (100%) rename {flash_examples => examples}/serve/generic/boston_prediction/client.py (100%) rename {flash_examples => examples}/serve/generic/boston_prediction/inference_server.py (100%) rename {flash_examples => examples}/serve/generic/boston_prediction/requirements.txt (100%) rename {flash_examples => examples}/serve/generic/detection/classes.txt (100%) rename {flash_examples => examples}/serve/generic/detection/client.py (100%) rename {flash_examples => examples}/serve/generic/detection/inference.py (100%) rename {flash_examples => examples}/serve/generic/detection/input.jpg (100%) rename {flash/core => examples/serve/image_classification}/__init__.py (100%) rename {flash_examples => examples}/serve/image_classification/client.py (100%) rename {flash_examples => examples}/serve/image_classification/inference_server.py (100%) rename {flash_examples => examples}/serve/object_detection/client.py (100%) rename {flash_examples => examples}/serve/object_detection/inference_server.py (100%) rename {flash_examples => examples}/serve/semantic_segmentation/client.py (100%) rename {flash_examples => examples}/serve/semantic_segmentation/inference_server.py (100%) rename {flash_examples => examples}/serve/speech_recognition/client.py (100%) rename {flash_examples => examples}/serve/speech_recognition/inference_server.py (100%) rename {flash_examples => examples}/serve/summarization/client.py (100%) rename {flash_examples => examples}/serve/summarization/inference_server.py (100%) rename {flash_examples => examples}/serve/tabular_classification/client.py (100%) rename {flash_examples => examples}/serve/tabular_classification/inference_server.py (100%) rename {flash_examples => examples}/serve/text_classification/client.py (100%) rename {flash_examples => examples}/serve/text_classification/inference_server.py (100%) rename {flash_examples => examples}/serve/translation/client.py (100%) rename {flash_examples => examples}/serve/translation/inference_server.py (100%) rename {flash_examples => examples}/speech_recognition.py (100%) rename {flash_examples => examples}/style_transfer.py (100%) rename {flash_examples => examples}/summarization.py (100%) rename {flash_examples => examples}/tabular_classification.py (100%) rename {flash_examples => examples}/tabular_forecasting.py (100%) rename {flash_examples => examples}/tabular_regression.py (100%) rename {flash_examples => examples}/template.py (100%) rename {flash_examples => examples}/text_classification.py (100%) rename {flash_examples => examples}/text_classification_multi_label.py (100%) rename {flash_examples => examples}/text_embedder.py (100%) rename {flash_examples => examples}/translation.py (100%) rename {flash_examples => examples}/video_classification.py (100%) rename {flash_examples => examples}/visualizations/pointcloud_detection.py (100%) rename {flash_examples => examples}/visualizations/pointcloud_segmentation.py (100%) rename {flash => src/flash}/__about__.py (100%) rename {flash => src/flash}/__init__.py (100%) rename {flash => src/flash}/__main__.py (100%) rename {flash => src/flash}/assets/example.wav (100%) rename {flash => src/flash}/assets/fish.jpg (100%) rename {flash => src/flash}/assets/road.png (100%) rename {flash => src/flash}/assets/starry_night.jpg (100%) rename {flash => src/flash}/audio/__init__.py (100%) rename {flash => src/flash}/audio/classification/__init__.py (100%) rename {flash => src/flash}/audio/classification/cli.py (100%) rename {flash => src/flash}/audio/classification/data.py (100%) rename {flash => src/flash}/audio/classification/input.py (100%) rename {flash => src/flash}/audio/classification/input_transform.py (100%) rename {flash => src/flash}/audio/speech_recognition/__init__.py (100%) rename {flash => src/flash}/audio/speech_recognition/backbone.py (100%) rename {flash => src/flash}/audio/speech_recognition/cli.py (100%) rename {flash => src/flash}/audio/speech_recognition/collate.py (100%) rename {flash => src/flash}/audio/speech_recognition/data.py (100%) rename {flash => src/flash}/audio/speech_recognition/input.py (100%) rename {flash => src/flash}/audio/speech_recognition/model.py (100%) rename {flash => src/flash}/audio/speech_recognition/output_transform.py (100%) rename {flash/core/data => src/flash/core}/__init__.py (100%) rename {flash => src/flash}/core/adapter.py (100%) rename {flash => src/flash}/core/classification.py (100%) rename {flash/core/data/io => src/flash/core/data}/__init__.py (100%) rename {flash => src/flash}/core/data/base_viz.py (100%) rename {flash => src/flash}/core/data/batch.py (100%) rename {flash => src/flash}/core/data/callback.py (100%) rename {flash => src/flash}/core/data/data_module.py (100%) rename {flash/core/data/utilities => src/flash/core/data/io}/__init__.py (100%) rename {flash => src/flash}/core/data/io/classification_input.py (100%) rename {flash => src/flash}/core/data/io/input.py (100%) rename {flash => src/flash}/core/data/io/input_transform.py (100%) rename {flash => src/flash}/core/data/io/output.py (100%) rename {flash => src/flash}/core/data/io/output_transform.py (100%) rename {flash => src/flash}/core/data/io/transform_predictions.py (100%) rename {flash => src/flash}/core/data/output.py (100%) rename {flash => src/flash}/core/data/properties.py (100%) rename {flash => src/flash}/core/data/splits.py (100%) rename {flash => src/flash}/core/data/transforms.py (100%) rename {flash/core/integrations => src/flash/core/data/utilities}/__init__.py (100%) rename {flash => src/flash}/core/data/utilities/classification.py (100%) rename {flash => src/flash}/core/data/utilities/collate.py (100%) rename {flash => src/flash}/core/data/utilities/data_frame.py (100%) rename {flash => src/flash}/core/data/utilities/loading.py (100%) rename {flash => src/flash}/core/data/utilities/paths.py (100%) rename {flash => src/flash}/core/data/utilities/samples.py (100%) rename {flash => src/flash}/core/data/utilities/sort.py (100%) rename {flash => src/flash}/core/data/utils.py (100%) rename {flash => src/flash}/core/finetuning.py (100%) rename {flash => src/flash}/core/heads.py (100%) rename {flash => src/flash}/core/hooks.py (100%) rename {flash/core/integrations/icevision => src/flash/core/integrations}/__init__.py (100%) rename {flash => src/flash}/core/integrations/fiftyone/__init__.py (100%) rename {flash => src/flash}/core/integrations/fiftyone/utils.py (100%) rename {flash/core/integrations/labelstudio => src/flash/core/integrations/icevision}/__init__.py (100%) rename {flash => src/flash}/core/integrations/icevision/adapter.py (100%) rename {flash => src/flash}/core/integrations/icevision/backbones.py (100%) rename {flash => src/flash}/core/integrations/icevision/data.py (100%) rename {flash => src/flash}/core/integrations/icevision/transforms.py (100%) rename {flash => src/flash}/core/integrations/icevision/wrappers.py (100%) rename {flash/core/integrations/pytorch_tabular => src/flash/core/integrations/labelstudio}/__init__.py (100%) rename {flash => src/flash}/core/integrations/labelstudio/input.py (100%) rename {flash => src/flash}/core/integrations/labelstudio/visualizer.py (100%) rename {flash => src/flash}/core/integrations/pytorch_forecasting/__init__.py (100%) rename {flash => src/flash}/core/integrations/pytorch_forecasting/adapter.py (100%) rename {flash => src/flash}/core/integrations/pytorch_forecasting/backbones.py (100%) rename {flash => src/flash}/core/integrations/pytorch_forecasting/transforms.py (100%) rename {flash/core/integrations/transformers => src/flash/core/integrations/pytorch_tabular}/__init__.py (100%) rename {flash => src/flash}/core/integrations/pytorch_tabular/adapter.py (100%) rename {flash => src/flash}/core/integrations/pytorch_tabular/backbones.py (100%) rename {flash/core/serve/dag => src/flash/core/integrations/transformers}/__init__.py (100%) rename {flash => src/flash}/core/integrations/transformers/collate.py (100%) rename {flash => src/flash}/core/model.py (100%) rename {flash => src/flash}/core/optimizers/__init__.py (100%) rename {flash => src/flash}/core/optimizers/lamb.py (100%) rename {flash => src/flash}/core/optimizers/lars.py (100%) rename {flash => src/flash}/core/optimizers/lr_scheduler.py (100%) rename {flash => src/flash}/core/optimizers/optimizers.py (100%) rename {flash => src/flash}/core/optimizers/schedulers.py (100%) rename {flash => src/flash}/core/registry.py (100%) rename {flash => src/flash}/core/regression.py (100%) rename {flash => src/flash}/core/serve/__init__.py (100%) rename {flash => src/flash}/core/serve/_compat/__init__.py (100%) rename {flash => src/flash}/core/serve/_compat/cached_property.py (100%) rename {flash => src/flash}/core/serve/component.py (100%) rename {flash => src/flash}/core/serve/composition.py (100%) rename {flash => src/flash}/core/serve/core.py (100%) rename {flash => src/flash}/core/serve/dag/NOTICE (100%) rename {flash/core/serve/interfaces => src/flash/core/serve/dag}/__init__.py (100%) rename {flash => src/flash}/core/serve/dag/optimization.py (100%) rename {flash => src/flash}/core/serve/dag/order.py (100%) rename {flash => src/flash}/core/serve/dag/rewrite.py (100%) rename {flash => src/flash}/core/serve/dag/task.py (100%) rename {flash => src/flash}/core/serve/dag/utils.py (100%) rename {flash => src/flash}/core/serve/dag/utils_test.py (100%) rename {flash => src/flash}/core/serve/dag/visualize.py (100%) rename {flash => src/flash}/core/serve/decorators.py (100%) rename {flash => src/flash}/core/serve/execution.py (100%) rename {flash => src/flash}/core/serve/flash_components.py (100%) rename {flash/core/serve/interfaces/templates => src/flash/core/serve/interfaces}/__init__.py (100%) rename {flash => src/flash}/core/serve/interfaces/http.py (100%) rename {flash => src/flash}/core/serve/interfaces/models.py (100%) rename {flash/core/utilities => src/flash/core/serve/interfaces/templates}/__init__.py (100%) rename {flash => src/flash}/core/serve/interfaces/templates/dag.html (100%) rename {flash => src/flash}/core/serve/server.py (100%) rename {flash => src/flash}/core/serve/types/__init__.py (100%) rename {flash => src/flash}/core/serve/types/base.py (100%) rename {flash => src/flash}/core/serve/types/bbox.py (100%) rename {flash => src/flash}/core/serve/types/image.py (100%) rename {flash => src/flash}/core/serve/types/label.py (100%) rename {flash => src/flash}/core/serve/types/number.py (100%) rename {flash => src/flash}/core/serve/types/repeated.py (100%) rename {flash => src/flash}/core/serve/types/table.py (100%) rename {flash => src/flash}/core/serve/types/text.py (100%) rename {flash => src/flash}/core/serve/utils.py (100%) rename {flash => src/flash}/core/trainer.py (100%) rename {flash/image/classification/integrations => src/flash/core/utilities}/__init__.py (100%) rename {flash => src/flash}/core/utilities/apply_func.py (100%) rename {flash => src/flash}/core/utilities/compatibility.py (100%) rename {flash => src/flash}/core/utilities/embedder.py (100%) rename {flash => src/flash}/core/utilities/flash_cli.py (100%) rename {flash => src/flash}/core/utilities/imports.py (100%) rename {flash => src/flash}/core/utilities/isinstance.py (100%) rename {flash => src/flash}/core/utilities/lightning_cli.py (100%) rename {flash => src/flash}/core/utilities/providers.py (100%) rename {flash => src/flash}/core/utilities/stability.py (100%) rename {flash => src/flash}/core/utilities/stages.py (100%) rename {flash => src/flash}/core/utilities/types.py (100%) rename {flash => src/flash}/core/utilities/url_error.py (100%) rename {flash => src/flash}/graph/__init__.py (100%) rename {flash => src/flash}/graph/backbones.py (100%) rename {flash => src/flash}/graph/classification/__init__.py (100%) rename {flash => src/flash}/graph/classification/cli.py (100%) rename {flash => src/flash}/graph/classification/data.py (100%) rename {flash => src/flash}/graph/classification/input.py (100%) rename {flash => src/flash}/graph/classification/input_transform.py (100%) rename {flash => src/flash}/graph/classification/model.py (100%) rename {flash => src/flash}/graph/collate.py (100%) rename {flash => src/flash}/graph/embedding/__init__.py (100%) rename {flash => src/flash}/graph/embedding/model.py (100%) rename {flash => src/flash}/image/__init__.py (100%) rename {flash => src/flash}/image/classification/__init__.py (100%) rename {flash => src/flash}/image/classification/adapters.py (100%) rename {flash => src/flash}/image/classification/backbones/__init__.py (100%) rename {flash => src/flash}/image/classification/backbones/clip.py (100%) rename {flash => src/flash}/image/classification/backbones/resnet.py (100%) rename {flash => src/flash}/image/classification/backbones/timm.py (100%) rename {flash => src/flash}/image/classification/backbones/torchvision.py (100%) rename {flash => src/flash}/image/classification/backbones/transformers.py (100%) rename {flash => src/flash}/image/classification/cli.py (100%) rename {flash => src/flash}/image/classification/data.py (100%) rename {flash => src/flash}/image/classification/heads.py (100%) rename {flash => src/flash}/image/classification/input.py (100%) rename {flash => src/flash}/image/classification/input_transform.py (100%) rename {flash/image/embedding/vissl => src/flash/image/classification/integrations}/__init__.py (100%) rename {flash => src/flash}/image/classification/integrations/baal/__init__.py (100%) rename {flash => src/flash}/image/classification/integrations/baal/data.py (100%) rename {flash => src/flash}/image/classification/integrations/baal/dropout.py (100%) rename {flash => src/flash}/image/classification/integrations/baal/loop.py (100%) rename {flash => src/flash}/image/classification/integrations/learn2learn.py (100%) rename {flash => src/flash}/image/classification/model.py (100%) rename {flash => src/flash}/image/data.py (100%) rename {flash => src/flash}/image/detection/__init__.py (100%) rename {flash => src/flash}/image/detection/backbones.py (100%) rename {flash => src/flash}/image/detection/cli.py (100%) rename {flash => src/flash}/image/detection/data.py (100%) rename {flash => src/flash}/image/detection/input.py (100%) rename {flash => src/flash}/image/detection/model.py (100%) rename {flash => src/flash}/image/detection/output.py (100%) rename {flash => src/flash}/image/embedding/__init__.py (100%) rename {flash => src/flash}/image/embedding/heads/__init__.py (100%) rename {flash => src/flash}/image/embedding/heads/vissl_heads.py (100%) rename {flash => src/flash}/image/embedding/losses/__init__.py (100%) rename {flash => src/flash}/image/embedding/losses/vissl_losses.py (100%) rename {flash => src/flash}/image/embedding/model.py (100%) rename {flash => src/flash}/image/embedding/strategies/__init__.py (100%) rename {flash => src/flash}/image/embedding/strategies/default.py (100%) rename {flash => src/flash}/image/embedding/strategies/vissl_strategies.py (100%) rename {flash => src/flash}/image/embedding/transforms/__init__.py (100%) rename {flash => src/flash}/image/embedding/transforms/vissl_transforms.py (100%) rename {flash/pointcloud/detection/open3d_ml => src/flash/image/embedding/vissl}/__init__.py (100%) rename {flash => src/flash}/image/embedding/vissl/adapter.py (100%) rename {flash => src/flash}/image/embedding/vissl/hooks.py (100%) rename {flash => src/flash}/image/embedding/vissl/transforms/__init__.py (100%) rename {flash => src/flash}/image/embedding/vissl/transforms/multicrop.py (100%) rename {flash => src/flash}/image/embedding/vissl/transforms/utilities.py (100%) rename {flash => src/flash}/image/face_detection/__init__.py (100%) rename {flash => src/flash}/image/face_detection/backbones/__init__.py (100%) rename {flash => src/flash}/image/face_detection/backbones/fastface_backbones.py (100%) rename {flash => src/flash}/image/face_detection/cli.py (100%) rename {flash => src/flash}/image/face_detection/data.py (100%) rename {flash => src/flash}/image/face_detection/input.py (100%) rename {flash => src/flash}/image/face_detection/input_transform.py (100%) rename {flash => src/flash}/image/face_detection/model.py (100%) rename {flash => src/flash}/image/face_detection/output_transform.py (100%) rename {flash => src/flash}/image/instance_segmentation/__init__.py (100%) rename {flash => src/flash}/image/instance_segmentation/backbones.py (100%) rename {flash => src/flash}/image/instance_segmentation/cli.py (100%) rename {flash => src/flash}/image/instance_segmentation/data.py (100%) rename {flash => src/flash}/image/instance_segmentation/model.py (100%) rename {flash => src/flash}/image/keypoint_detection/__init__.py (100%) rename {flash => src/flash}/image/keypoint_detection/backbones.py (100%) rename {flash => src/flash}/image/keypoint_detection/cli.py (100%) rename {flash => src/flash}/image/keypoint_detection/data.py (100%) rename {flash => src/flash}/image/keypoint_detection/input_transform.py (100%) rename {flash => src/flash}/image/keypoint_detection/model.py (100%) rename {flash => src/flash}/image/segmentation/__init__.py (100%) rename {flash => src/flash}/image/segmentation/backbones.py (100%) rename {flash => src/flash}/image/segmentation/cli.py (100%) rename {flash => src/flash}/image/segmentation/data.py (100%) rename {flash => src/flash}/image/segmentation/heads.py (100%) rename {flash => src/flash}/image/segmentation/input.py (100%) rename {flash => src/flash}/image/segmentation/input_transform.py (100%) rename {flash => src/flash}/image/segmentation/model.py (100%) rename {flash => src/flash}/image/segmentation/output.py (100%) rename {flash => src/flash}/image/segmentation/viz.py (100%) rename {flash => src/flash}/image/style_transfer/__init__.py (100%) rename {flash => src/flash}/image/style_transfer/backbones.py (100%) rename {flash => src/flash}/image/style_transfer/cli.py (100%) rename {flash => src/flash}/image/style_transfer/data.py (100%) rename {flash => src/flash}/image/style_transfer/input_transform.py (100%) rename {flash => src/flash}/image/style_transfer/model.py (100%) rename {flash => src/flash}/image/style_transfer/utils.py (100%) rename {flash => src/flash}/pointcloud/__init__.py (100%) rename {flash => src/flash}/pointcloud/detection/__init__.py (100%) rename {flash => src/flash}/pointcloud/detection/backbones.py (100%) rename {flash => src/flash}/pointcloud/detection/cli.py (100%) rename {flash => src/flash}/pointcloud/detection/data.py (100%) rename {flash => src/flash}/pointcloud/detection/datasets.py (100%) rename {flash => src/flash}/pointcloud/detection/input.py (100%) rename {flash => src/flash}/pointcloud/detection/model.py (100%) rename {flash/pointcloud/segmentation => src/flash/pointcloud/detection}/open3d_ml/__init__.py (100%) rename {flash => src/flash}/pointcloud/detection/open3d_ml/app.py (100%) rename {flash => src/flash}/pointcloud/detection/open3d_ml/backbones.py (100%) rename {flash => src/flash}/pointcloud/detection/open3d_ml/input.py (100%) rename {flash => src/flash}/pointcloud/segmentation/__init__.py (100%) rename {flash => src/flash}/pointcloud/segmentation/backbones.py (100%) rename {flash => src/flash}/pointcloud/segmentation/cli.py (100%) rename {flash => src/flash}/pointcloud/segmentation/data.py (100%) rename {flash => src/flash}/pointcloud/segmentation/datasets.py (100%) rename {flash => src/flash}/pointcloud/segmentation/input.py (100%) rename {flash => src/flash}/pointcloud/segmentation/model.py (100%) rename {flash/video/classification => src/flash/pointcloud/segmentation/open3d_ml}/__init__.py (100%) rename {flash => src/flash}/pointcloud/segmentation/open3d_ml/app.py (100%) rename {flash => src/flash}/pointcloud/segmentation/open3d_ml/backbones.py (100%) rename {flash => src/flash}/pointcloud/segmentation/open3d_ml/sequences_dataset.py (100%) rename {flash => src/flash}/tabular/__init__.py (100%) rename {flash => src/flash}/tabular/classification/__init__.py (100%) rename {flash => src/flash}/tabular/classification/cli.py (100%) rename {flash => src/flash}/tabular/classification/data.py (100%) rename {flash => src/flash}/tabular/classification/input.py (100%) rename {flash => src/flash}/tabular/classification/model.py (100%) rename {flash => src/flash}/tabular/classification/utils.py (100%) rename {flash => src/flash}/tabular/data.py (100%) rename {flash => src/flash}/tabular/forecasting/__init__.py (100%) rename {flash => src/flash}/tabular/forecasting/cli.py (100%) rename {flash => src/flash}/tabular/forecasting/data.py (100%) rename {flash => src/flash}/tabular/forecasting/input.py (100%) rename {flash => src/flash}/tabular/forecasting/model.py (100%) rename {flash => src/flash}/tabular/input.py (100%) rename {flash => src/flash}/tabular/regression/__init__.py (100%) rename {flash => src/flash}/tabular/regression/cli.py (100%) rename {flash => src/flash}/tabular/regression/data.py (100%) rename {flash => src/flash}/tabular/regression/input.py (100%) rename {flash => src/flash}/tabular/regression/model.py (100%) rename {flash => src/flash}/template/__init__.py (100%) rename {flash => src/flash}/template/classification/__init__.py (100%) rename {flash => src/flash}/template/classification/backbones.py (100%) rename {flash => src/flash}/template/classification/data.py (100%) rename {flash => src/flash}/template/classification/model.py (100%) rename {flash => src/flash}/text/__init__.py (100%) rename {flash => src/flash}/text/classification/__init__.py (100%) rename {flash => src/flash}/text/classification/adapters.py (100%) rename {flash => src/flash}/text/classification/backbones/__init__.py (100%) rename {flash => src/flash}/text/classification/backbones/clip.py (100%) rename {flash => src/flash}/text/classification/backbones/huggingface.py (100%) rename {flash => src/flash}/text/classification/cli.py (100%) rename {flash => src/flash}/text/classification/collate.py (100%) rename {flash => src/flash}/text/classification/data.py (100%) rename {flash => src/flash}/text/classification/input.py (100%) rename {flash => src/flash}/text/classification/model.py (100%) rename {flash => src/flash}/text/embedding/__init__.py (100%) rename {flash => src/flash}/text/embedding/backbones.py (100%) rename {flash => src/flash}/text/embedding/model.py (100%) rename {flash => src/flash}/text/input.py (100%) rename {flash => src/flash}/text/ort_callback.py (100%) rename {flash => src/flash}/text/question_answering/__init__.py (100%) rename {flash => src/flash}/text/question_answering/cli.py (100%) rename {flash => src/flash}/text/question_answering/collate.py (100%) rename {flash => src/flash}/text/question_answering/data.py (100%) rename {flash => src/flash}/text/question_answering/input.py (100%) rename {flash => src/flash}/text/question_answering/model.py (100%) rename {flash => src/flash}/text/question_answering/output_transform.py (100%) rename {flash => src/flash}/text/seq2seq/__init__.py (100%) rename {flash => src/flash}/text/seq2seq/core/__init__.py (100%) rename {flash => src/flash}/text/seq2seq/core/collate.py (100%) rename {flash => src/flash}/text/seq2seq/core/input.py (100%) rename {flash => src/flash}/text/seq2seq/core/model.py (100%) rename {flash => src/flash}/text/seq2seq/summarization/__init__.py (100%) rename {flash => src/flash}/text/seq2seq/summarization/cli.py (100%) rename {flash => src/flash}/text/seq2seq/summarization/data.py (100%) rename {flash => src/flash}/text/seq2seq/summarization/model.py (100%) rename {flash => src/flash}/text/seq2seq/translation/__init__.py (100%) rename {flash => src/flash}/text/seq2seq/translation/cli.py (100%) rename {flash => src/flash}/text/seq2seq/translation/data.py (100%) rename {flash => src/flash}/text/seq2seq/translation/model.py (100%) rename {flash => src/flash}/video/__init__.py (100%) rename {flash_examples/serve/image_classification => src/flash/video/classification}/__init__.py (100%) rename {flash => src/flash}/video/classification/cli.py (100%) rename {flash => src/flash}/video/classification/data.py (100%) rename {flash => src/flash}/video/classification/input.py (100%) rename {flash => src/flash}/video/classification/input_transform.py (100%) rename {flash => src/flash}/video/classification/model.py (100%) rename {flash => src/flash}/video/classification/utils.py (100%) diff --git a/.github/labeler.yml b/.github/labeler.yml index b190ac2113..77aa8c0001 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -3,23 +3,23 @@ documentation: - README.md examples: - - flash_examples/**/* + - examples/**/* data: - - flash/core/data/**/* + - src/flash/core/data/**/* task: - - flash/tabular/**/* - - flash/text/**/* - - flash/image/**/* - - flash/video/**/* + - src/flash/tabular/**/* + - src/flash/text/**/* + - src/flash/image/**/* + - src/flash/video/**/* tabular: - - flash/tabular/**/* + - src/flash/tabular/**/* text: - - flash/text/**/* + - src/flash/text/**/* vision: - - flash/image/**/* - - flash/video/**/* + - src/flash/image/**/* + - src/flash/video/**/* diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index d6884eefff..5f5dfe999b 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -48,7 +48,7 @@ jobs: - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'audio' ] } # Timeout: https://stackoverflow.com/a/59076067/4521646 - timeout-minutes: 35 + timeout-minutes: 50 env: FREEZE_REQUIREMENTS: 1 @@ -89,17 +89,6 @@ jobs: if: matrix.requires != 'latest' id: times - - name: Get pip cache dir - id: pip-cache - run: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - - - name: Cache pip - uses: actions/cache@v3 - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-${{ matrix.python-version }}-td${{ steps.times.outputs.period }}-${{ join(matrix.topic,'-') }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements.txt') }} - restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-td${{ steps.times.outputs.period }}-${{ join(matrix.topic,'-') }}-${{ matrix.requires }}-pip- - - name: Install graph test dependencies if: contains( matrix.topic , 'graph' ) run: | @@ -148,7 +137,9 @@ jobs: FLASH_TEST_TOPIC: ${{ join(matrix.topic,',') }} FIFTYONE_DO_NOT_TRACK: true run: | - coverage run --source flash -m pytest flash tests --reruns 3 --reruns-delay 2 -v + pip list + # FixMe: include doctests for src/ + coverage run --source flash -m pytest tests/ -v --reruns 3 --reruns-delay 2 - name: Statistics if: success() diff --git a/.gitignore b/.gitignore index 41cfb04542..33480ab9fb 100644 --- a/.gitignore +++ b/.gitignore @@ -158,10 +158,10 @@ movie_posters CameraRGB CameraSeg jigsaw_toxic_comments -flash_examples/serve/tabular_classification/data +examples/serve/tabular_classification/data logs/cache/* -flash_examples/data -flash_examples/checkpoints +examples/data +examples/checkpoints timit/ urban8k_images/ __MACOSX diff --git a/.readthedocs.yml b/.readthedocs.yml index 566c41fd49..3700ebb138 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -31,3 +31,5 @@ build: python: install: - requirements: requirements/docs.txt + - method: pip + path: . diff --git a/MANIFEST.in b/MANIFEST.in index 8aad84f3de..5a819748e9 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,11 +5,11 @@ recursive-exclude __pycache__ *.py[cod] *.orig # Include the README and CHANGELOG include *.md -recursive-include flash *.md -recursive-include flash *.py +recursive-include src *.md +recursive-include src *.py # Include assets -recursive-include flash/assets *.jpg *.png +recursive-include src/flash/assets * # Include the license file include LICENSE @@ -18,15 +18,6 @@ exclude *.sh exclude *.toml exclude *.svg -# exclude tests from package -recursive-exclude tests * -recursive-exclude site * -exclude tests - -# Exclude the documentation files -recursive-exclude docs * -exclude docs - # Include the Requirements include requirements/*.txt include requirements.txt @@ -36,6 +27,7 @@ exclude *.yml prune .git prune .github +prune docs prune notebook* prune temp* prune test* diff --git a/README.md b/README.md index c485d01d1b..da21a904a2 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,7 @@ predictions = trainer.predict(model, dm) Training strategies are PyTorch SOTA Training Recipes which can be utilized with a given task. -Check out this [example](https://github.com/Lightning-AI/lightning-flash/blob/master/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py) where the `ImageClassifier` supports 4 [Meta Learning Algorithms](https://lilianweng.github.io/lil-log/2018/11/30/meta-learning.html) from [Learn2Learn](https://github.com/learnables/learn2learn). +Check out this [example](https://github.com/Lightning-AI/lightning-flash/blob/master/examples/integrations/learn2learn/image_classification_imagenette_mini.py) where the `ImageClassifier` supports 4 [Meta Learning Algorithms](https://lilianweng.github.io/lil-log/2018/11/30/meta-learning.html) from [Learn2Learn](https://github.com/learnables/learn2learn). This is particularly useful if you use this model in production and want to make sure the model adapts quickly to its new environment with minimal labelled data. ```py diff --git a/docs/source/conf.py b/docs/source/conf.py index d574567d77..7fe77a3a47 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -36,7 +36,7 @@ def _load_py_module(fname, pkg="flash"): - spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_ROOT, pkg, fname)) + spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_ROOT, "src", pkg, fname)) py = module_from_spec(spec) spec.loader.exec_module(py) return py diff --git a/docs/source/general/production.rst b/docs/source/general/production.rst index 59e07b74c4..7804b4b2af 100644 --- a/docs/source/general/production.rst +++ b/docs/source/general/production.rst @@ -10,7 +10,7 @@ Flash Serve makes model deployment simple. Server Side ^^^^^^^^^^^ -.. literalinclude:: ../../../flash_examples/serve/semantic_segmentation/inference_server.py +.. literalinclude:: ../../../examples/serve/semantic_segmentation/inference_server.py :language: python :lines: 14- @@ -18,7 +18,7 @@ Server Side Client Side ^^^^^^^^^^^ -.. literalinclude:: ../../../flash_examples/serve/semantic_segmentation/client.py +.. literalinclude:: ../../../examples/serve/semantic_segmentation/client.py :language: python :lines: 14- diff --git a/docs/source/general/serve.rst b/docs/source/general/serve.rst index 5ddab0c914..b94eae4d6d 100644 --- a/docs/source/general/serve.rst +++ b/docs/source/general/serve.rst @@ -41,7 +41,7 @@ Example In this tutorial, we will serve a Resnet18 from the `PyTorchVision library `_ in 3 steps. -The entire tutorial can be found under ``flash_examples/serve/generic``. +The entire tutorial can be found under ``examples/serve/generic``. Introduction ============ diff --git a/docs/source/integrations/baal.rst b/docs/source/integrations/baal.rst index 163af500d7..8088729f84 100644 --- a/docs/source/integrations/baal.rst +++ b/docs/source/integrations/baal.rst @@ -29,6 +29,6 @@ The most uncertain samples will be labelled by the human to accelerate the model With its integration within Flash, the Active Learning process is simpler than ever before. -.. literalinclude:: ../../../flash_examples/integrations/baal/image_classification_active_learning.py +.. literalinclude:: ../../../examples/integrations/baal/image_classification_active_learning.py :language: python :lines: 14- diff --git a/docs/source/integrations/fiftyone.rst b/docs/source/integrations/fiftyone.rst index 1c64173272..2b1fda86d1 100644 --- a/docs/source/integrations/fiftyone.rst +++ b/docs/source/integrations/fiftyone.rst @@ -57,7 +57,7 @@ dictionaries containing :ref:`FiftyOne Label ` objects and filepaths, which is exactly the output of the FiftyOne outputs when the ``return_filepath=True`` option is specified. -.. literalinclude:: ../../../flash_examples/integrations/fiftyone/image_classification.py +.. literalinclude:: ../../../examples/integrations/fiftyone/image_classification.py :language: python :lines: 14- @@ -94,7 +94,7 @@ method allows you to load your FiftyOne datasets directly into a :class:`~flash.core.data.data_module.DataModule` to be used for training, testing, or inference. -.. literalinclude:: ../../../flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py +.. literalinclude:: ../../../examples/integrations/fiftyone/image_classification_fiftyone_datasets.py :language: python :lines: 14- @@ -109,7 +109,7 @@ FiftyOne provides the methods for powerful workflows like clustering, similarity search, pre-annotation, and more in only a few lines of code. -.. literalinclude:: ../../../flash_examples/integrations/fiftyone/image_embedding.py +.. literalinclude:: ../../../examples/integrations/fiftyone/image_embedding.py :language: python :lines: 14- diff --git a/docs/source/integrations/learn2learn.rst b/docs/source/integrations/learn2learn.rst index 6435704be3..f16ccffe00 100644 --- a/docs/source/integrations/learn2learn.rst +++ b/docs/source/integrations/learn2learn.rst @@ -72,7 +72,7 @@ Once done, the users are left to play the hyper-parameters associated with the m Here is an example using `miniImageNet dataset `_ containing 100 classes divided into 64 training, 16 validation, and 20 test classes. -.. literalinclude:: ../../../flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py +.. literalinclude:: ../../../examples/integrations/learn2learn/image_classification_imagenette_mini.py :language: python :lines: 15- diff --git a/docs/source/integrations/pytorch_forecasting.rst b/docs/source/integrations/pytorch_forecasting.rst index dbec7cabe8..49776a2dab 100644 --- a/docs/source/integrations/pytorch_forecasting.rst +++ b/docs/source/integrations/pytorch_forecasting.rst @@ -13,7 +13,7 @@ With these, you can train your model and perform inference using Flash but still Here's an example, plotting the predictions and interpretation analysis from the NBeats model trained in the :ref:`tabular_forecasting` documentation: -.. literalinclude:: ../../../flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py +.. literalinclude:: ../../../examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py :language: python :lines: 14- diff --git a/docs/source/reference/audio_classification.rst b/docs/source/reference/audio_classification.rst index d4fa45953a..ffc9dfc4d1 100644 --- a/docs/source/reference/audio_classification.rst +++ b/docs/source/reference/audio_classification.rst @@ -73,7 +73,7 @@ We then use the trained :class:`~flash.image.classification.model.ImageClassifie Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/audio_classification.py +.. literalinclude:: ../../../examples/audio_classification.py :language: python :lines: 14- diff --git a/docs/source/reference/graph_classification.rst b/docs/source/reference/graph_classification.rst index 758d3ceb69..92fde7042a 100644 --- a/docs/source/reference/graph_classification.rst +++ b/docs/source/reference/graph_classification.rst @@ -34,7 +34,7 @@ Next, we use the trained :class:`~flash.graph.classification.model.GraphClassifi Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/graph_classification.py +.. literalinclude:: ../../../examples/graph_classification.py :language: python :lines: 14- diff --git a/docs/source/reference/graph_embedder.rst b/docs/source/reference/graph_embedder.rst index 17c8ad6ba0..c57323fb8a 100644 --- a/docs/source/reference/graph_embedder.rst +++ b/docs/source/reference/graph_embedder.rst @@ -23,7 +23,7 @@ Next, we load a trained :class:`~flash.graph.classification.model.GraphEmbedder` Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/graph_embedder.py +.. literalinclude:: ../../../examples/graph_embedder.py :language: python :lines: 14 diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index 20a830c8dd..098f03b9ef 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -56,7 +56,7 @@ We then use the trained :class:`~flash.image.classification.model.ImageClassifie Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/image_classification.py +.. literalinclude:: ../../../examples/image_classification.py :language: python :lines: 14- @@ -181,12 +181,12 @@ The :class:`~flash.image.classification.model.ImageClassifier` is servable. This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`. Here's an example: -.. literalinclude:: ../../../flash_examples/serve/image_classification/inference_server.py +.. literalinclude:: ../../../examples/serve/image_classification/inference_server.py :language: python :lines: 14- You can now perform inference from your client like this: -.. literalinclude:: ../../../flash_examples/serve/image_classification/client.py +.. literalinclude:: ../../../examples/serve/image_classification/client.py :language: python :lines: 14- diff --git a/docs/source/reference/image_classification_multi_label.rst b/docs/source/reference/image_classification_multi_label.rst index fec3c1e5d3..fc2b42889f 100644 --- a/docs/source/reference/image_classification_multi_label.rst +++ b/docs/source/reference/image_classification_multi_label.rst @@ -50,7 +50,7 @@ We then use the trained :class:`~flash.image.classification.model.ImageClassifie Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/image_classification_multi_label.py +.. literalinclude:: ../../../examples/image_classification_multi_label.py :language: python :lines: 14- diff --git a/docs/source/reference/image_embedder.rst b/docs/source/reference/image_embedder.rst index 0229364644..dc2e66ac20 100644 --- a/docs/source/reference/image_embedder.rst +++ b/docs/source/reference/image_embedder.rst @@ -40,7 +40,7 @@ Next, we configure the :class:`~flash.image.embedding.model.ImageEmbedder` task Finally, we construct a :class:`~flash.core.trainer.Trainer` and call ``fit()``. Here's the full example: -.. literalinclude:: ../../../flash_examples/image_embedder.py +.. literalinclude:: ../../../examples/image_embedder.py :language: python :lines: 14- diff --git a/docs/source/reference/instance_segmentation.rst b/docs/source/reference/instance_segmentation.rst index a786c994c8..9d9ba82218 100644 --- a/docs/source/reference/instance_segmentation.rst +++ b/docs/source/reference/instance_segmentation.rst @@ -31,7 +31,7 @@ We then use the trained :class:`~flash.image.instance_segmentation.model.Instanc Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/instance_segmentation.py +.. literalinclude:: ../../../examples/instance_segmentation.py :language: python :lines: 14- diff --git a/docs/source/reference/keypoint_detection.rst b/docs/source/reference/keypoint_detection.rst index 10202a453b..692451220a 100644 --- a/docs/source/reference/keypoint_detection.rst +++ b/docs/source/reference/keypoint_detection.rst @@ -31,7 +31,7 @@ We then use the trained :class:`~flash.image.keypoint_detection.model.KeypointDe Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/keypoint_detection.py +.. literalinclude:: ../../../examples/keypoint_detection.py :language: python :lines: 14- diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index b9b3b5cfe3..4aab8e3d5e 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -51,7 +51,7 @@ We then use the trained :class:`~flash.image.detection.model.ObjectDetector` for Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/object_detection.py +.. literalinclude:: ../../../examples/object_detection.py :language: python :lines: 14- @@ -126,12 +126,12 @@ The :class:`~flash.image.detection.model.ObjectDetector` is servable. This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`. Here's an example: -.. literalinclude:: ../../../flash_examples/serve/object_detection/inference_server.py +.. literalinclude:: ../../../examples/serve/object_detection/inference_server.py :language: python :lines: 14- You can now perform inference from your client like this: -.. literalinclude:: ../../../flash_examples/serve/object_detection/client.py +.. literalinclude:: ../../../examples/serve/object_detection/client.py :language: python :lines: 14- diff --git a/docs/source/reference/pointcloud_object_detection.rst b/docs/source/reference/pointcloud_object_detection.rst index edeb2c9aa3..ebb85302b3 100644 --- a/docs/source/reference/pointcloud_object_detection.rst +++ b/docs/source/reference/pointcloud_object_detection.rst @@ -80,7 +80,7 @@ We then use the trained :class:`~flash.image.detection.model.PointCloudObjectDet Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/pointcloud_detection.py +.. literalinclude:: ../../../examples/pointcloud_detection.py :language: python :lines: 14- diff --git a/docs/source/reference/pointcloud_segmentation.rst b/docs/source/reference/pointcloud_segmentation.rst index 6b4fe25bb4..6b441f2211 100644 --- a/docs/source/reference/pointcloud_segmentation.rst +++ b/docs/source/reference/pointcloud_segmentation.rst @@ -71,7 +71,7 @@ We then use the trained ``PointCloudSegmentation`` for inference. Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/pointcloud_segmentation.py +.. literalinclude:: ../../../examples/pointcloud_segmentation.py :language: python :lines: 14- diff --git a/docs/source/reference/question_answering.rst b/docs/source/reference/question_answering.rst index d4686a290a..29b710181b 100644 --- a/docs/source/reference/question_answering.rst +++ b/docs/source/reference/question_answering.rst @@ -60,7 +60,7 @@ Next, we use the trained :class:`~flash.text.question_answering.model.QuestionAn Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/question_answering.py +.. literalinclude:: ../../../examples/question_answering.py :language: python :lines: 14- diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst index 749c153702..1f5738629c 100644 --- a/docs/source/reference/semantic_segmentation.rst +++ b/docs/source/reference/semantic_segmentation.rst @@ -45,7 +45,7 @@ We then use the trained :class:`~flash.image.segmentation.model.SemanticSegmenta Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/semantic_segmentation.py +.. literalinclude:: ../../../examples/semantic_segmentation.py :language: python :lines: 14- @@ -81,12 +81,12 @@ The :class:`~flash.image.segmentation.model.SemanticSegmentation` task is servab This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`. Here's an example: -.. literalinclude:: ../../../flash_examples/serve/semantic_segmentation/inference_server.py +.. literalinclude:: ../../../examples/serve/semantic_segmentation/inference_server.py :language: python :lines: 14- You can now perform inference from your client like this: -.. literalinclude:: ../../../flash_examples/serve/semantic_segmentation/client.py +.. literalinclude:: ../../../examples/serve/semantic_segmentation/client.py :language: python :lines: 14- diff --git a/docs/source/reference/speech_recognition.rst b/docs/source/reference/speech_recognition.rst index d527355805..f9a8753f48 100644 --- a/docs/source/reference/speech_recognition.rst +++ b/docs/source/reference/speech_recognition.rst @@ -49,7 +49,7 @@ The backbone can be any Wav2Vec model from `HuggingFace transformers -.. literalinclude:: ../../../flash_examples/template.py +.. literalinclude:: ../../../examples/template.py :language: python :lines: 14- diff --git a/docs/source/reference/text_classification.rst b/docs/source/reference/text_classification.rst index 0677d9e59f..a372755745 100644 --- a/docs/source/reference/text_classification.rst +++ b/docs/source/reference/text_classification.rst @@ -49,7 +49,7 @@ Next, we use the trained :class:`~flash.text.classification.model.TextClassifier Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/text_classification.py +.. literalinclude:: ../../../examples/text_classification.py :language: python :lines: 14- @@ -84,13 +84,13 @@ The :class:`~flash.text.classification.model.TextClassifier` is servable. This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`. Here's an example: -.. literalinclude:: ../../../flash_examples/serve/text_classification/inference_server.py +.. literalinclude:: ../../../examples/serve/text_classification/inference_server.py :language: python :lines: 14- You can now perform inference from your client like this: -.. literalinclude:: ../../../flash_examples/serve/text_classification/client.py +.. literalinclude:: ../../../examples/serve/text_classification/client.py :language: python :lines: 14- diff --git a/docs/source/reference/text_classification_multi_label.rst b/docs/source/reference/text_classification_multi_label.rst index a8317fd9ae..408f3b24b7 100644 --- a/docs/source/reference/text_classification_multi_label.rst +++ b/docs/source/reference/text_classification_multi_label.rst @@ -47,7 +47,7 @@ Next, we use the trained :class:`~flash.text.classification.model.TextClassifier Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/text_classification_multi_label.py +.. literalinclude:: ../../../examples/text_classification_multi_label.py :language: python :lines: 14- diff --git a/docs/source/reference/text_embedder.rst b/docs/source/reference/text_embedder.rst index 5483988e74..38d33f572e 100644 --- a/docs/source/reference/text_embedder.rst +++ b/docs/source/reference/text_embedder.rst @@ -30,7 +30,7 @@ Next, we create our :class:`~flash.text.embedding.model.TextEmbedder` with a pre Finally, we create a :class:`~flash.core.trainer.Trainer` and generate sentence embeddings. Here's the full example: -.. literalinclude:: ../../../flash_examples/text_embedder.py +.. literalinclude:: ../../../examples/text_embedder.py :language: python :lines: 14- diff --git a/docs/source/reference/translation.rst b/docs/source/reference/translation.rst index 171258d24b..3eb40df4ca 100644 --- a/docs/source/reference/translation.rst +++ b/docs/source/reference/translation.rst @@ -49,7 +49,7 @@ Next, we use the trained :class:`~flash.text.seq2seq.translation.model.Translati Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/translation.py +.. literalinclude:: ../../../examples/translation.py :language: python :lines: 14- @@ -84,13 +84,13 @@ The :class:`~flash.text.seq2seq.translation.model.TranslationTask` is servable. This means you can call ``.serve`` to serve your :class:`~flash.core.model.Task`. Here's an example: -.. literalinclude:: ../../../flash_examples/serve/translation/inference_server.py +.. literalinclude:: ../../../examples/serve/translation/inference_server.py :language: python :lines: 14- You can now perform inference from your client like this: -.. literalinclude:: ../../../flash_examples/serve/translation/client.py +.. literalinclude:: ../../../examples/serve/translation/client.py :language: python :lines: 14- diff --git a/docs/source/reference/video_classification.rst b/docs/source/reference/video_classification.rst index fe477f89e3..5ae7697a1f 100644 --- a/docs/source/reference/video_classification.rst +++ b/docs/source/reference/video_classification.rst @@ -58,7 +58,7 @@ We then use the trained :class:`~flash.video.classification.model.VideoClassifie Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../flash_examples/video_classification.py +.. literalinclude:: ../../../examples/video_classification.py :language: python :lines: 14- diff --git a/docs/source/template/backbones.rst b/docs/source/template/backbones.rst index bcbac896a2..278e001ec6 100644 --- a/docs/source/template/backbones.rst +++ b/docs/source/template/backbones.rst @@ -24,19 +24,19 @@ You also need to provide ``name`` and ``namespace`` of the backbone. The standard for *namespace* is ``data_type/task_type``, so for an image classification task the namespace will be ``image/classification``. Here's the code: -.. literalinclude:: ../../../flash/template/classification/backbones.py +.. literalinclude:: ../../../src/flash/template/classification/backbones.py :language: python :pyobject: load_mlp_128 Here's another example with a slightly more complex model: -.. literalinclude:: ../../../flash/template/classification/backbones.py +.. literalinclude:: ../../../src/flash/template/classification/backbones.py :language: python :pyobject: load_mlp_128_256 Here's a another example, which adds ``DINO`` pretrained model from PyTorch Hub to the ``IMAGE_CLASSIFIER_BACKBONES``, from `flash/image/classification/backbones/transformers.py `_: -.. literalinclude:: ../../../flash/image/classification/backbones/transformers.py +.. literalinclude:: ../../../src/flash/image/classification/backbones/transformers.py :language: python :pyobject: dino_vitb16 diff --git a/docs/source/template/data.rst b/docs/source/template/data.rst index 6b696cacf3..dee4450542 100644 --- a/docs/source/template/data.rst +++ b/docs/source/template/data.rst @@ -34,14 +34,14 @@ In this ``Input``, we'll also set the ``num_features`` attribute so that we can Here's the code for our ``TemplateNumpyClassificationInput.load_data`` method: -.. literalinclude:: ../../../flash/template/classification/data.py +.. literalinclude:: ../../../src/flash/template/classification/data.py :language: python :dedent: 4 :pyobject: TemplateNumpyClassificationInput.load_data and here's the code for the ``TemplateNumpyClassificationInput.load_sample`` method: -.. literalinclude:: ../../../flash/template/classification/data.py +.. literalinclude:: ../../../src/flash/template/classification/data.py :language: python :dedent: 4 :pyobject: TemplateNumpyClassificationInput.load_sample @@ -58,7 +58,7 @@ We perform two additional steps here to improve the user experience: Here's the code for the ``TemplateSKLearnClassificationInput.load_data`` method: -.. literalinclude:: ../../../flash/template/classification/data.py +.. literalinclude:: ../../../src/flash/template/classification/data.py :language: python :dedent: 4 :pyobject: TemplateSKLearnClassificationInput.load_data @@ -67,7 +67,7 @@ We can customize the behaviour of our :meth:`~flash.core.data.io.input.Input.loa For our ``TemplateSKLearnClassificationInput``, we don't want to provide any targets to the model when predicting. We can implement ``predict_load_data`` like this: -.. literalinclude:: ../../../flash/template/classification/data.py +.. literalinclude:: ../../../src/flash/template/classification/data.py :language: python :dedent: 4 :pyobject: TemplateSKLearnClassificationInput.predict_load_data @@ -83,14 +83,14 @@ Defining the standard transforms (typically at least a ``per_sample_transform`` For our ``TemplateInputTransform``, we'll just configure a ``per_sample_transform``. Let's first define a to_tensor transform as a ``staticmethod``: -.. literalinclude:: ../../../flash/template/classification/data.py +.. literalinclude:: ../../../src/flash/template/classification/data.py :language: python :dedent: 4 :pyobject: TemplateInputTransform.to_tensor Now in our ``per_sample_transform`` hook, we return the transform: -.. literalinclude:: ../../../flash/template/classification/data.py +.. literalinclude:: ../../../src/flash/template/classification/data.py :language: python :dedent: 4 :pyobject: TemplateInputTransform.per_sample_transform @@ -122,7 +122,7 @@ Since we provided a :attr:`~flash.core.data.io.input.InputFormat.NUMPY` :class:` If you've defined a fully custom :class:`~flash.core.data.io.input.Input` (like our ``TemplateSKLearnClassificationInput``), then you will need to write a ``from_*`` method for each. Here's the ``from_sklearn`` method for our ``TemplateData``: -.. literalinclude:: ../../../flash/template/classification/data.py +.. literalinclude:: ../../../src/flash/template/classification/data.py :language: python :dedent: 4 :pyobject: TemplateData.from_sklearn @@ -131,7 +131,7 @@ The final step is to implement the ``num_features`` property for our ``TemplateD This is just a convenience for the user that finds the ``num_features`` attribute on any of the data sets and returns it. Here's the code: -.. literalinclude:: ../../../flash/template/classification/data.py +.. literalinclude:: ../../../src/flash/template/classification/data.py :language: python :dedent: 4 :pyobject: TemplateData.num_features @@ -148,13 +148,13 @@ This is extremely useful for debugging purposes, allowing users to view their da Here's the code for our ``TemplateVisualization`` which just prints the data: -.. literalinclude:: ../../../flash/template/classification/data.py +.. literalinclude:: ../../../src/flash/template/classification/data.py :language: python :pyobject: TemplateVisualization We can configure our custom visualization in the ``TemplateData`` using :meth:`~flash.core.data.data_module.DataModule.configure_data_fetcher` like this: -.. literalinclude:: ../../../flash/template/classification/data.py +.. literalinclude:: ../../../src/flash/template/classification/data.py :language: python :dedent: 4 :pyobject: TemplateData.configure_data_fetcher @@ -166,7 +166,7 @@ OutputTransform You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc. As an example, here's the :class:`~image.segmentation.model.SemanticSegmentationOutputTransform` which decodes tokenized model outputs: -.. literalinclude:: ../../../flash/image/segmentation/model.py +.. literalinclude:: ../../../src/flash/image/segmentation/model.py :language: python :pyobject: SemanticSegmentationOutputTransform @@ -176,7 +176,7 @@ You should use this approach if your postprocessing depends on the state of the For example, if you want to resize the predictions to the original size of the inputs you should add the original image size in the :attr:`~flash.core.data.io.input.DataKeys.METADATA`. Here's an example from the :class:`~flash.image.data.ImageInput`: -.. literalinclude:: ../../../flash/image/data.py +.. literalinclude:: ../../../src/flash/image/data.py :language: python :dedent: 4 :pyobject: ImageInput.load_sample @@ -184,7 +184,7 @@ Here's an example from the :class:`~flash.image.data.ImageInput`: The :attr:`~flash.core.data.io.input.DataKeys.METADATA` can now be referenced in your :class:`~flash.core.data.io.output_transform.OutputTransform`. For example, here's the code for the ``per_sample_transform`` method of the :class:`~flash.image.segmentation.model.SemanticSegmentationOutputTransform`: -.. literalinclude:: ../../../flash/image/segmentation/model.py +.. literalinclude:: ../../../src/flash/image/segmentation/model.py :language: python :dedent: 4 :pyobject: SemanticSegmentationOutputTransform.per_sample_transform diff --git a/docs/source/template/examples.rst b/docs/source/template/examples.rst index 08a2d92103..bd993b79c4 100644 --- a/docs/source/template/examples.rst +++ b/docs/source/template/examples.rst @@ -5,7 +5,7 @@ The Example *********** Now you've implemented your task, it's time to add an example showing how cool it is! -We usually provide one example in `flash_examples/ `_. +We usually provide one example in `examples/ `_. You can base these off of our ``template.py`` examples. The example should: @@ -19,9 +19,9 @@ The example should: #. save the checkpoint For our template example we don't have a pretrained backbone, so we can just call :meth:`~flash.core.trainer.Trainer.fit` rather than :meth:`~flash.core.trainer.Trainer.finetune`. -Here's the full example (`flash_examples/template.py `_): +Here's the full example (`examples/template.py `_): -.. literalinclude:: ../../../flash_examples/template.py +.. literalinclude:: ../../../examples/template.py :language: python :lines: 14- diff --git a/docs/source/template/optional.rst b/docs/source/template/optional.rst index 32f24152be..046073300b 100644 --- a/docs/source/template/optional.rst +++ b/docs/source/template/optional.rst @@ -10,7 +10,7 @@ Organize your transforms in transforms.py It can be useful to define your :class:`~flash.core.data.io.input_transform.InputTransform` in an ``input_transform.py`` file. Here's an example from `image/classification/input_transform.py `_: -.. literalinclude:: ../../../flash/image/classification/input_transform.py +.. literalinclude:: ../../../src/flash/image/classification/input_transform.py :language: python :pyobject: ImageClassificationInputTransform @@ -24,13 +24,13 @@ If you want to support different use cases that require different prediction for Some good examples are in `flash/core/classification.py `_. Here's the :class:`~flash.core.classification.ClassesOutput` :class:`~flash.core.data.io.output.Output`: -.. literalinclude:: ../../../flash/core/classification.py +.. literalinclude:: ../../../src/flash/core/classification.py :language: python :pyobject: ClassesOutput Alternatively, here's the :class:`~flash.core.classification.LogitsOutput` :class:`~flash.core.data.io.output.Output`: -.. literalinclude:: ../../../flash/core/classification.py +.. literalinclude:: ../../../src/flash/core/classification.py :language: python :pyobject: LogitsOutput diff --git a/docs/source/template/task.rst b/docs/source/template/task.rst index 6dc1496d70..d5c6d9583e 100644 --- a/docs/source/template/task.rst +++ b/docs/source/template/task.rst @@ -31,7 +31,7 @@ In the :meth:`~flash.core.model.Task.__init__`, you will need to configure defau You will also need to create the backbone from the registry and create the model head. Here's the code: -.. literalinclude:: ../../../flash/template/classification/model.py +.. literalinclude:: ../../../src/flash/template/classification/model.py :language: python :dedent: 4 :pyobject: TemplateSKLearnClassifier.__init__ @@ -46,7 +46,7 @@ The default ``{train,val,test,predict}_step`` implementations in :class:`~flash. In our template example, we just extract the input and target from the input mapping and forward them to the ``super`` methods. Here's the code for the ``training_step``: -.. literalinclude:: ../../../flash/template/classification/model.py +.. literalinclude:: ../../../src/flash/template/classification/model.py :language: python :dedent: 4 :pyobject: TemplateSKLearnClassifier.training_step @@ -54,7 +54,7 @@ Here's the code for the ``training_step``: We use the same code for the ``validation_step`` and ``test_step``. For ``predict_step`` we don't need the targets, so our code looks like this: -.. literalinclude:: ../../../flash/template/classification/model.py +.. literalinclude:: ../../../src/flash/template/classification/model.py :language: python :dedent: 4 :pyobject: TemplateSKLearnClassifier.predict_step @@ -63,7 +63,7 @@ For ``predict_step`` we don't need the targets, so our code looks like this: Finally, we use our backbone and head in a custom forward pass: -.. literalinclude:: ../../../flash/template/classification/model.py +.. literalinclude:: ../../../src/flash/template/classification/model.py :language: python :dedent: 4 :pyobject: TemplateSKLearnClassifier.forward diff --git a/flash_examples/audio_classification.py b/examples/audio_classification.py similarity index 100% rename from flash_examples/audio_classification.py rename to examples/audio_classification.py diff --git a/flash_examples/face_detection.py b/examples/face_detection.py similarity index 100% rename from flash_examples/face_detection.py rename to examples/face_detection.py diff --git a/flash_examples/graph_classification.py b/examples/graph_classification.py similarity index 100% rename from flash_examples/graph_classification.py rename to examples/graph_classification.py diff --git a/flash_examples/graph_embedder.py b/examples/graph_embedder.py similarity index 100% rename from flash_examples/graph_embedder.py rename to examples/graph_embedder.py diff --git a/flash_examples/image_classification.py b/examples/image_classification.py similarity index 100% rename from flash_examples/image_classification.py rename to examples/image_classification.py diff --git a/flash_examples/image_classification_multi_label.py b/examples/image_classification_multi_label.py similarity index 100% rename from flash_examples/image_classification_multi_label.py rename to examples/image_classification_multi_label.py diff --git a/flash_examples/image_embedder.py b/examples/image_embedder.py similarity index 100% rename from flash_examples/image_embedder.py rename to examples/image_embedder.py diff --git a/flash_examples/instance_segmentation.py b/examples/instance_segmentation.py similarity index 100% rename from flash_examples/instance_segmentation.py rename to examples/instance_segmentation.py diff --git a/flash_examples/integrations/baal/image_classification_active_learning.py b/examples/integrations/baal/image_classification_active_learning.py similarity index 100% rename from flash_examples/integrations/baal/image_classification_active_learning.py rename to examples/integrations/baal/image_classification_active_learning.py diff --git a/flash_examples/integrations/fiftyone/image_classification.py b/examples/integrations/fiftyone/image_classification.py similarity index 100% rename from flash_examples/integrations/fiftyone/image_classification.py rename to examples/integrations/fiftyone/image_classification.py diff --git a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py b/examples/integrations/fiftyone/image_classification_fiftyone_datasets.py similarity index 100% rename from flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py rename to examples/integrations/fiftyone/image_classification_fiftyone_datasets.py diff --git a/flash_examples/integrations/fiftyone/image_embedding.py b/examples/integrations/fiftyone/image_embedding.py similarity index 100% rename from flash_examples/integrations/fiftyone/image_embedding.py rename to examples/integrations/fiftyone/image_embedding.py diff --git a/flash_examples/integrations/fiftyone/object_detection.py b/examples/integrations/fiftyone/object_detection.py similarity index 100% rename from flash_examples/integrations/fiftyone/object_detection.py rename to examples/integrations/fiftyone/object_detection.py diff --git a/flash_examples/integrations/labelstudio/image_classification.py b/examples/integrations/labelstudio/image_classification.py similarity index 100% rename from flash_examples/integrations/labelstudio/image_classification.py rename to examples/integrations/labelstudio/image_classification.py diff --git a/flash_examples/integrations/labelstudio/text_classification.py b/examples/integrations/labelstudio/text_classification.py similarity index 100% rename from flash_examples/integrations/labelstudio/text_classification.py rename to examples/integrations/labelstudio/text_classification.py diff --git a/flash_examples/integrations/labelstudio/video_classification.py b/examples/integrations/labelstudio/video_classification.py similarity index 100% rename from flash_examples/integrations/labelstudio/video_classification.py rename to examples/integrations/labelstudio/video_classification.py diff --git a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py b/examples/integrations/learn2learn/image_classification_imagenette_mini.py similarity index 100% rename from flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py rename to examples/integrations/learn2learn/image_classification_imagenette_mini.py diff --git a/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py b/examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py similarity index 100% rename from flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py rename to examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py diff --git a/flash_examples/keypoint_detection.py b/examples/keypoint_detection.py similarity index 100% rename from flash_examples/keypoint_detection.py rename to examples/keypoint_detection.py diff --git a/flash_examples/object_detection.py b/examples/object_detection.py similarity index 100% rename from flash_examples/object_detection.py rename to examples/object_detection.py diff --git a/flash_examples/pointcloud_detection.py b/examples/pointcloud_detection.py similarity index 100% rename from flash_examples/pointcloud_detection.py rename to examples/pointcloud_detection.py diff --git a/flash_examples/pointcloud_segmentation.py b/examples/pointcloud_segmentation.py similarity index 100% rename from flash_examples/pointcloud_segmentation.py rename to examples/pointcloud_segmentation.py diff --git a/flash_examples/question_answering.py b/examples/question_answering.py similarity index 100% rename from flash_examples/question_answering.py rename to examples/question_answering.py diff --git a/flash_examples/semantic_segmentation.py b/examples/semantic_segmentation.py similarity index 100% rename from flash_examples/semantic_segmentation.py rename to examples/semantic_segmentation.py diff --git a/flash_examples/serve/generic/boston_prediction/client.py b/examples/serve/generic/boston_prediction/client.py similarity index 100% rename from flash_examples/serve/generic/boston_prediction/client.py rename to examples/serve/generic/boston_prediction/client.py diff --git a/flash_examples/serve/generic/boston_prediction/inference_server.py b/examples/serve/generic/boston_prediction/inference_server.py similarity index 100% rename from flash_examples/serve/generic/boston_prediction/inference_server.py rename to examples/serve/generic/boston_prediction/inference_server.py diff --git a/flash_examples/serve/generic/boston_prediction/requirements.txt b/examples/serve/generic/boston_prediction/requirements.txt similarity index 100% rename from flash_examples/serve/generic/boston_prediction/requirements.txt rename to examples/serve/generic/boston_prediction/requirements.txt diff --git a/flash_examples/serve/generic/detection/classes.txt b/examples/serve/generic/detection/classes.txt similarity index 100% rename from flash_examples/serve/generic/detection/classes.txt rename to examples/serve/generic/detection/classes.txt diff --git a/flash_examples/serve/generic/detection/client.py b/examples/serve/generic/detection/client.py similarity index 100% rename from flash_examples/serve/generic/detection/client.py rename to examples/serve/generic/detection/client.py diff --git a/flash_examples/serve/generic/detection/inference.py b/examples/serve/generic/detection/inference.py similarity index 100% rename from flash_examples/serve/generic/detection/inference.py rename to examples/serve/generic/detection/inference.py diff --git a/flash_examples/serve/generic/detection/input.jpg b/examples/serve/generic/detection/input.jpg similarity index 100% rename from flash_examples/serve/generic/detection/input.jpg rename to examples/serve/generic/detection/input.jpg diff --git a/flash/core/__init__.py b/examples/serve/image_classification/__init__.py similarity index 100% rename from flash/core/__init__.py rename to examples/serve/image_classification/__init__.py diff --git a/flash_examples/serve/image_classification/client.py b/examples/serve/image_classification/client.py similarity index 100% rename from flash_examples/serve/image_classification/client.py rename to examples/serve/image_classification/client.py diff --git a/flash_examples/serve/image_classification/inference_server.py b/examples/serve/image_classification/inference_server.py similarity index 100% rename from flash_examples/serve/image_classification/inference_server.py rename to examples/serve/image_classification/inference_server.py diff --git a/flash_examples/serve/object_detection/client.py b/examples/serve/object_detection/client.py similarity index 100% rename from flash_examples/serve/object_detection/client.py rename to examples/serve/object_detection/client.py diff --git a/flash_examples/serve/object_detection/inference_server.py b/examples/serve/object_detection/inference_server.py similarity index 100% rename from flash_examples/serve/object_detection/inference_server.py rename to examples/serve/object_detection/inference_server.py diff --git a/flash_examples/serve/semantic_segmentation/client.py b/examples/serve/semantic_segmentation/client.py similarity index 100% rename from flash_examples/serve/semantic_segmentation/client.py rename to examples/serve/semantic_segmentation/client.py diff --git a/flash_examples/serve/semantic_segmentation/inference_server.py b/examples/serve/semantic_segmentation/inference_server.py similarity index 100% rename from flash_examples/serve/semantic_segmentation/inference_server.py rename to examples/serve/semantic_segmentation/inference_server.py diff --git a/flash_examples/serve/speech_recognition/client.py b/examples/serve/speech_recognition/client.py similarity index 100% rename from flash_examples/serve/speech_recognition/client.py rename to examples/serve/speech_recognition/client.py diff --git a/flash_examples/serve/speech_recognition/inference_server.py b/examples/serve/speech_recognition/inference_server.py similarity index 100% rename from flash_examples/serve/speech_recognition/inference_server.py rename to examples/serve/speech_recognition/inference_server.py diff --git a/flash_examples/serve/summarization/client.py b/examples/serve/summarization/client.py similarity index 100% rename from flash_examples/serve/summarization/client.py rename to examples/serve/summarization/client.py diff --git a/flash_examples/serve/summarization/inference_server.py b/examples/serve/summarization/inference_server.py similarity index 100% rename from flash_examples/serve/summarization/inference_server.py rename to examples/serve/summarization/inference_server.py diff --git a/flash_examples/serve/tabular_classification/client.py b/examples/serve/tabular_classification/client.py similarity index 100% rename from flash_examples/serve/tabular_classification/client.py rename to examples/serve/tabular_classification/client.py diff --git a/flash_examples/serve/tabular_classification/inference_server.py b/examples/serve/tabular_classification/inference_server.py similarity index 100% rename from flash_examples/serve/tabular_classification/inference_server.py rename to examples/serve/tabular_classification/inference_server.py diff --git a/flash_examples/serve/text_classification/client.py b/examples/serve/text_classification/client.py similarity index 100% rename from flash_examples/serve/text_classification/client.py rename to examples/serve/text_classification/client.py diff --git a/flash_examples/serve/text_classification/inference_server.py b/examples/serve/text_classification/inference_server.py similarity index 100% rename from flash_examples/serve/text_classification/inference_server.py rename to examples/serve/text_classification/inference_server.py diff --git a/flash_examples/serve/translation/client.py b/examples/serve/translation/client.py similarity index 100% rename from flash_examples/serve/translation/client.py rename to examples/serve/translation/client.py diff --git a/flash_examples/serve/translation/inference_server.py b/examples/serve/translation/inference_server.py similarity index 100% rename from flash_examples/serve/translation/inference_server.py rename to examples/serve/translation/inference_server.py diff --git a/flash_examples/speech_recognition.py b/examples/speech_recognition.py similarity index 100% rename from flash_examples/speech_recognition.py rename to examples/speech_recognition.py diff --git a/flash_examples/style_transfer.py b/examples/style_transfer.py similarity index 100% rename from flash_examples/style_transfer.py rename to examples/style_transfer.py diff --git a/flash_examples/summarization.py b/examples/summarization.py similarity index 100% rename from flash_examples/summarization.py rename to examples/summarization.py diff --git a/flash_examples/tabular_classification.py b/examples/tabular_classification.py similarity index 100% rename from flash_examples/tabular_classification.py rename to examples/tabular_classification.py diff --git a/flash_examples/tabular_forecasting.py b/examples/tabular_forecasting.py similarity index 100% rename from flash_examples/tabular_forecasting.py rename to examples/tabular_forecasting.py diff --git a/flash_examples/tabular_regression.py b/examples/tabular_regression.py similarity index 100% rename from flash_examples/tabular_regression.py rename to examples/tabular_regression.py diff --git a/flash_examples/template.py b/examples/template.py similarity index 100% rename from flash_examples/template.py rename to examples/template.py diff --git a/flash_examples/text_classification.py b/examples/text_classification.py similarity index 100% rename from flash_examples/text_classification.py rename to examples/text_classification.py diff --git a/flash_examples/text_classification_multi_label.py b/examples/text_classification_multi_label.py similarity index 100% rename from flash_examples/text_classification_multi_label.py rename to examples/text_classification_multi_label.py diff --git a/flash_examples/text_embedder.py b/examples/text_embedder.py similarity index 100% rename from flash_examples/text_embedder.py rename to examples/text_embedder.py diff --git a/flash_examples/translation.py b/examples/translation.py similarity index 100% rename from flash_examples/translation.py rename to examples/translation.py diff --git a/flash_examples/video_classification.py b/examples/video_classification.py similarity index 100% rename from flash_examples/video_classification.py rename to examples/video_classification.py diff --git a/flash_examples/visualizations/pointcloud_detection.py b/examples/visualizations/pointcloud_detection.py similarity index 100% rename from flash_examples/visualizations/pointcloud_detection.py rename to examples/visualizations/pointcloud_detection.py diff --git a/flash_examples/visualizations/pointcloud_segmentation.py b/examples/visualizations/pointcloud_segmentation.py similarity index 100% rename from flash_examples/visualizations/pointcloud_segmentation.py rename to examples/visualizations/pointcloud_segmentation.py diff --git a/requirements.txt b/requirements.txt index 3c42fa44e1..e7da88a101 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,11 @@ +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup + packaging setuptools<=59.5.0 # Prevent install bug with tensorboard numpy<1.24 # strict - freeze for using np.long torch>=1.7.1 -torchmetrics>0.5.1, <0.11.0 -pytorch-lightning>=1.3.6, <1.9.0 +torchmetrics>0.5.1, <0.11.0 # strict +pytorch-lightning>=1.3.6, <1.9.0 # strict pyDeprecate pandas>=1.1.0, <=1.5.2 jsonargparse[signatures]>=3.17.0, <=4.9.0 diff --git a/requirements/test.txt b/requirements/test.txt index b1344ce6a7..fd36809355 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -6,8 +6,7 @@ pytest>=5.0, <7.0 pytest-doctestplus>=0.9.0 pytest-rerunfailures>=10.0 pytest-forked +pytest-mock scikit-learn -pytest_mock -qpth <0.0.14 torch_optimizer diff --git a/setup.cfg b/setup.cfg index d55a737e65..6ab4af21b5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,6 @@ norecursedirs = build doctest_plus = enabled addopts = - --strict --durations=0 --color=yes @@ -26,7 +25,7 @@ exclude_lines = [isort] known_first_party = flash - flash_examples + examples tests line_length = 120 order_by_type = False @@ -75,7 +74,7 @@ ignore = # Typing tests is low priority, but enabling type checking on the # untyped test functions (using `--check-untyped-defs`) is still # high-value because it helps test the typing. -files = flash, flash_examples, tests +files = flash, examples, tests pretty = True show_error_codes = True disallow_untyped_defs = True diff --git a/setup.py b/setup.py index bfca41797b..c7ddb9e911 100644 --- a/setup.py +++ b/setup.py @@ -108,7 +108,7 @@ def _load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: boo def _load_py_module(fname, pkg="flash"): - spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_ROOT, pkg, fname)) + spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_ROOT, "src", pkg, fname)) py = module_from_spec(spec) spec.loader.exec_module(py) return py @@ -162,7 +162,8 @@ def _get_extras(path_dir: str = _PATH_REQUIRE) -> dict: url=about.__homepage__, download_url="https://github.com/Lightning-AI/lightning-flash", license=about.__license__, - packages=find_packages(exclude=["tests", "tests.*"]), + package_dir={"": "src"}, + packages=find_packages(where="src"), long_description=_load_readme_description(_PATH_ROOT, homepage=about.__homepage__, ver=about.__version__), long_description_content_type="text/markdown", include_package_data=True, diff --git a/flash/__about__.py b/src/flash/__about__.py similarity index 100% rename from flash/__about__.py rename to src/flash/__about__.py diff --git a/flash/__init__.py b/src/flash/__init__.py similarity index 100% rename from flash/__init__.py rename to src/flash/__init__.py diff --git a/flash/__main__.py b/src/flash/__main__.py similarity index 100% rename from flash/__main__.py rename to src/flash/__main__.py diff --git a/flash/assets/example.wav b/src/flash/assets/example.wav similarity index 100% rename from flash/assets/example.wav rename to src/flash/assets/example.wav diff --git a/flash/assets/fish.jpg b/src/flash/assets/fish.jpg similarity index 100% rename from flash/assets/fish.jpg rename to src/flash/assets/fish.jpg diff --git a/flash/assets/road.png b/src/flash/assets/road.png similarity index 100% rename from flash/assets/road.png rename to src/flash/assets/road.png diff --git a/flash/assets/starry_night.jpg b/src/flash/assets/starry_night.jpg similarity index 100% rename from flash/assets/starry_night.jpg rename to src/flash/assets/starry_night.jpg diff --git a/flash/audio/__init__.py b/src/flash/audio/__init__.py similarity index 100% rename from flash/audio/__init__.py rename to src/flash/audio/__init__.py diff --git a/flash/audio/classification/__init__.py b/src/flash/audio/classification/__init__.py similarity index 100% rename from flash/audio/classification/__init__.py rename to src/flash/audio/classification/__init__.py diff --git a/flash/audio/classification/cli.py b/src/flash/audio/classification/cli.py similarity index 100% rename from flash/audio/classification/cli.py rename to src/flash/audio/classification/cli.py diff --git a/flash/audio/classification/data.py b/src/flash/audio/classification/data.py similarity index 100% rename from flash/audio/classification/data.py rename to src/flash/audio/classification/data.py diff --git a/flash/audio/classification/input.py b/src/flash/audio/classification/input.py similarity index 100% rename from flash/audio/classification/input.py rename to src/flash/audio/classification/input.py diff --git a/flash/audio/classification/input_transform.py b/src/flash/audio/classification/input_transform.py similarity index 100% rename from flash/audio/classification/input_transform.py rename to src/flash/audio/classification/input_transform.py diff --git a/flash/audio/speech_recognition/__init__.py b/src/flash/audio/speech_recognition/__init__.py similarity index 100% rename from flash/audio/speech_recognition/__init__.py rename to src/flash/audio/speech_recognition/__init__.py diff --git a/flash/audio/speech_recognition/backbone.py b/src/flash/audio/speech_recognition/backbone.py similarity index 100% rename from flash/audio/speech_recognition/backbone.py rename to src/flash/audio/speech_recognition/backbone.py diff --git a/flash/audio/speech_recognition/cli.py b/src/flash/audio/speech_recognition/cli.py similarity index 100% rename from flash/audio/speech_recognition/cli.py rename to src/flash/audio/speech_recognition/cli.py diff --git a/flash/audio/speech_recognition/collate.py b/src/flash/audio/speech_recognition/collate.py similarity index 100% rename from flash/audio/speech_recognition/collate.py rename to src/flash/audio/speech_recognition/collate.py diff --git a/flash/audio/speech_recognition/data.py b/src/flash/audio/speech_recognition/data.py similarity index 100% rename from flash/audio/speech_recognition/data.py rename to src/flash/audio/speech_recognition/data.py diff --git a/flash/audio/speech_recognition/input.py b/src/flash/audio/speech_recognition/input.py similarity index 100% rename from flash/audio/speech_recognition/input.py rename to src/flash/audio/speech_recognition/input.py diff --git a/flash/audio/speech_recognition/model.py b/src/flash/audio/speech_recognition/model.py similarity index 100% rename from flash/audio/speech_recognition/model.py rename to src/flash/audio/speech_recognition/model.py diff --git a/flash/audio/speech_recognition/output_transform.py b/src/flash/audio/speech_recognition/output_transform.py similarity index 100% rename from flash/audio/speech_recognition/output_transform.py rename to src/flash/audio/speech_recognition/output_transform.py diff --git a/flash/core/data/__init__.py b/src/flash/core/__init__.py similarity index 100% rename from flash/core/data/__init__.py rename to src/flash/core/__init__.py diff --git a/flash/core/adapter.py b/src/flash/core/adapter.py similarity index 100% rename from flash/core/adapter.py rename to src/flash/core/adapter.py diff --git a/flash/core/classification.py b/src/flash/core/classification.py similarity index 100% rename from flash/core/classification.py rename to src/flash/core/classification.py diff --git a/flash/core/data/io/__init__.py b/src/flash/core/data/__init__.py similarity index 100% rename from flash/core/data/io/__init__.py rename to src/flash/core/data/__init__.py diff --git a/flash/core/data/base_viz.py b/src/flash/core/data/base_viz.py similarity index 100% rename from flash/core/data/base_viz.py rename to src/flash/core/data/base_viz.py diff --git a/flash/core/data/batch.py b/src/flash/core/data/batch.py similarity index 100% rename from flash/core/data/batch.py rename to src/flash/core/data/batch.py diff --git a/flash/core/data/callback.py b/src/flash/core/data/callback.py similarity index 100% rename from flash/core/data/callback.py rename to src/flash/core/data/callback.py diff --git a/flash/core/data/data_module.py b/src/flash/core/data/data_module.py similarity index 100% rename from flash/core/data/data_module.py rename to src/flash/core/data/data_module.py diff --git a/flash/core/data/utilities/__init__.py b/src/flash/core/data/io/__init__.py similarity index 100% rename from flash/core/data/utilities/__init__.py rename to src/flash/core/data/io/__init__.py diff --git a/flash/core/data/io/classification_input.py b/src/flash/core/data/io/classification_input.py similarity index 100% rename from flash/core/data/io/classification_input.py rename to src/flash/core/data/io/classification_input.py diff --git a/flash/core/data/io/input.py b/src/flash/core/data/io/input.py similarity index 100% rename from flash/core/data/io/input.py rename to src/flash/core/data/io/input.py diff --git a/flash/core/data/io/input_transform.py b/src/flash/core/data/io/input_transform.py similarity index 100% rename from flash/core/data/io/input_transform.py rename to src/flash/core/data/io/input_transform.py diff --git a/flash/core/data/io/output.py b/src/flash/core/data/io/output.py similarity index 100% rename from flash/core/data/io/output.py rename to src/flash/core/data/io/output.py diff --git a/flash/core/data/io/output_transform.py b/src/flash/core/data/io/output_transform.py similarity index 100% rename from flash/core/data/io/output_transform.py rename to src/flash/core/data/io/output_transform.py diff --git a/flash/core/data/io/transform_predictions.py b/src/flash/core/data/io/transform_predictions.py similarity index 100% rename from flash/core/data/io/transform_predictions.py rename to src/flash/core/data/io/transform_predictions.py diff --git a/flash/core/data/output.py b/src/flash/core/data/output.py similarity index 100% rename from flash/core/data/output.py rename to src/flash/core/data/output.py diff --git a/flash/core/data/properties.py b/src/flash/core/data/properties.py similarity index 100% rename from flash/core/data/properties.py rename to src/flash/core/data/properties.py diff --git a/flash/core/data/splits.py b/src/flash/core/data/splits.py similarity index 100% rename from flash/core/data/splits.py rename to src/flash/core/data/splits.py diff --git a/flash/core/data/transforms.py b/src/flash/core/data/transforms.py similarity index 100% rename from flash/core/data/transforms.py rename to src/flash/core/data/transforms.py diff --git a/flash/core/integrations/__init__.py b/src/flash/core/data/utilities/__init__.py similarity index 100% rename from flash/core/integrations/__init__.py rename to src/flash/core/data/utilities/__init__.py diff --git a/flash/core/data/utilities/classification.py b/src/flash/core/data/utilities/classification.py similarity index 100% rename from flash/core/data/utilities/classification.py rename to src/flash/core/data/utilities/classification.py diff --git a/flash/core/data/utilities/collate.py b/src/flash/core/data/utilities/collate.py similarity index 100% rename from flash/core/data/utilities/collate.py rename to src/flash/core/data/utilities/collate.py diff --git a/flash/core/data/utilities/data_frame.py b/src/flash/core/data/utilities/data_frame.py similarity index 100% rename from flash/core/data/utilities/data_frame.py rename to src/flash/core/data/utilities/data_frame.py diff --git a/flash/core/data/utilities/loading.py b/src/flash/core/data/utilities/loading.py similarity index 100% rename from flash/core/data/utilities/loading.py rename to src/flash/core/data/utilities/loading.py diff --git a/flash/core/data/utilities/paths.py b/src/flash/core/data/utilities/paths.py similarity index 100% rename from flash/core/data/utilities/paths.py rename to src/flash/core/data/utilities/paths.py diff --git a/flash/core/data/utilities/samples.py b/src/flash/core/data/utilities/samples.py similarity index 100% rename from flash/core/data/utilities/samples.py rename to src/flash/core/data/utilities/samples.py diff --git a/flash/core/data/utilities/sort.py b/src/flash/core/data/utilities/sort.py similarity index 100% rename from flash/core/data/utilities/sort.py rename to src/flash/core/data/utilities/sort.py diff --git a/flash/core/data/utils.py b/src/flash/core/data/utils.py similarity index 100% rename from flash/core/data/utils.py rename to src/flash/core/data/utils.py diff --git a/flash/core/finetuning.py b/src/flash/core/finetuning.py similarity index 100% rename from flash/core/finetuning.py rename to src/flash/core/finetuning.py diff --git a/flash/core/heads.py b/src/flash/core/heads.py similarity index 100% rename from flash/core/heads.py rename to src/flash/core/heads.py diff --git a/flash/core/hooks.py b/src/flash/core/hooks.py similarity index 100% rename from flash/core/hooks.py rename to src/flash/core/hooks.py diff --git a/flash/core/integrations/icevision/__init__.py b/src/flash/core/integrations/__init__.py similarity index 100% rename from flash/core/integrations/icevision/__init__.py rename to src/flash/core/integrations/__init__.py diff --git a/flash/core/integrations/fiftyone/__init__.py b/src/flash/core/integrations/fiftyone/__init__.py similarity index 100% rename from flash/core/integrations/fiftyone/__init__.py rename to src/flash/core/integrations/fiftyone/__init__.py diff --git a/flash/core/integrations/fiftyone/utils.py b/src/flash/core/integrations/fiftyone/utils.py similarity index 100% rename from flash/core/integrations/fiftyone/utils.py rename to src/flash/core/integrations/fiftyone/utils.py diff --git a/flash/core/integrations/labelstudio/__init__.py b/src/flash/core/integrations/icevision/__init__.py similarity index 100% rename from flash/core/integrations/labelstudio/__init__.py rename to src/flash/core/integrations/icevision/__init__.py diff --git a/flash/core/integrations/icevision/adapter.py b/src/flash/core/integrations/icevision/adapter.py similarity index 100% rename from flash/core/integrations/icevision/adapter.py rename to src/flash/core/integrations/icevision/adapter.py diff --git a/flash/core/integrations/icevision/backbones.py b/src/flash/core/integrations/icevision/backbones.py similarity index 100% rename from flash/core/integrations/icevision/backbones.py rename to src/flash/core/integrations/icevision/backbones.py diff --git a/flash/core/integrations/icevision/data.py b/src/flash/core/integrations/icevision/data.py similarity index 100% rename from flash/core/integrations/icevision/data.py rename to src/flash/core/integrations/icevision/data.py diff --git a/flash/core/integrations/icevision/transforms.py b/src/flash/core/integrations/icevision/transforms.py similarity index 100% rename from flash/core/integrations/icevision/transforms.py rename to src/flash/core/integrations/icevision/transforms.py diff --git a/flash/core/integrations/icevision/wrappers.py b/src/flash/core/integrations/icevision/wrappers.py similarity index 100% rename from flash/core/integrations/icevision/wrappers.py rename to src/flash/core/integrations/icevision/wrappers.py diff --git a/flash/core/integrations/pytorch_tabular/__init__.py b/src/flash/core/integrations/labelstudio/__init__.py similarity index 100% rename from flash/core/integrations/pytorch_tabular/__init__.py rename to src/flash/core/integrations/labelstudio/__init__.py diff --git a/flash/core/integrations/labelstudio/input.py b/src/flash/core/integrations/labelstudio/input.py similarity index 100% rename from flash/core/integrations/labelstudio/input.py rename to src/flash/core/integrations/labelstudio/input.py diff --git a/flash/core/integrations/labelstudio/visualizer.py b/src/flash/core/integrations/labelstudio/visualizer.py similarity index 100% rename from flash/core/integrations/labelstudio/visualizer.py rename to src/flash/core/integrations/labelstudio/visualizer.py diff --git a/flash/core/integrations/pytorch_forecasting/__init__.py b/src/flash/core/integrations/pytorch_forecasting/__init__.py similarity index 100% rename from flash/core/integrations/pytorch_forecasting/__init__.py rename to src/flash/core/integrations/pytorch_forecasting/__init__.py diff --git a/flash/core/integrations/pytorch_forecasting/adapter.py b/src/flash/core/integrations/pytorch_forecasting/adapter.py similarity index 100% rename from flash/core/integrations/pytorch_forecasting/adapter.py rename to src/flash/core/integrations/pytorch_forecasting/adapter.py diff --git a/flash/core/integrations/pytorch_forecasting/backbones.py b/src/flash/core/integrations/pytorch_forecasting/backbones.py similarity index 100% rename from flash/core/integrations/pytorch_forecasting/backbones.py rename to src/flash/core/integrations/pytorch_forecasting/backbones.py diff --git a/flash/core/integrations/pytorch_forecasting/transforms.py b/src/flash/core/integrations/pytorch_forecasting/transforms.py similarity index 100% rename from flash/core/integrations/pytorch_forecasting/transforms.py rename to src/flash/core/integrations/pytorch_forecasting/transforms.py diff --git a/flash/core/integrations/transformers/__init__.py b/src/flash/core/integrations/pytorch_tabular/__init__.py similarity index 100% rename from flash/core/integrations/transformers/__init__.py rename to src/flash/core/integrations/pytorch_tabular/__init__.py diff --git a/flash/core/integrations/pytorch_tabular/adapter.py b/src/flash/core/integrations/pytorch_tabular/adapter.py similarity index 100% rename from flash/core/integrations/pytorch_tabular/adapter.py rename to src/flash/core/integrations/pytorch_tabular/adapter.py diff --git a/flash/core/integrations/pytorch_tabular/backbones.py b/src/flash/core/integrations/pytorch_tabular/backbones.py similarity index 100% rename from flash/core/integrations/pytorch_tabular/backbones.py rename to src/flash/core/integrations/pytorch_tabular/backbones.py diff --git a/flash/core/serve/dag/__init__.py b/src/flash/core/integrations/transformers/__init__.py similarity index 100% rename from flash/core/serve/dag/__init__.py rename to src/flash/core/integrations/transformers/__init__.py diff --git a/flash/core/integrations/transformers/collate.py b/src/flash/core/integrations/transformers/collate.py similarity index 100% rename from flash/core/integrations/transformers/collate.py rename to src/flash/core/integrations/transformers/collate.py diff --git a/flash/core/model.py b/src/flash/core/model.py similarity index 100% rename from flash/core/model.py rename to src/flash/core/model.py diff --git a/flash/core/optimizers/__init__.py b/src/flash/core/optimizers/__init__.py similarity index 100% rename from flash/core/optimizers/__init__.py rename to src/flash/core/optimizers/__init__.py diff --git a/flash/core/optimizers/lamb.py b/src/flash/core/optimizers/lamb.py similarity index 100% rename from flash/core/optimizers/lamb.py rename to src/flash/core/optimizers/lamb.py diff --git a/flash/core/optimizers/lars.py b/src/flash/core/optimizers/lars.py similarity index 100% rename from flash/core/optimizers/lars.py rename to src/flash/core/optimizers/lars.py diff --git a/flash/core/optimizers/lr_scheduler.py b/src/flash/core/optimizers/lr_scheduler.py similarity index 100% rename from flash/core/optimizers/lr_scheduler.py rename to src/flash/core/optimizers/lr_scheduler.py diff --git a/flash/core/optimizers/optimizers.py b/src/flash/core/optimizers/optimizers.py similarity index 100% rename from flash/core/optimizers/optimizers.py rename to src/flash/core/optimizers/optimizers.py diff --git a/flash/core/optimizers/schedulers.py b/src/flash/core/optimizers/schedulers.py similarity index 100% rename from flash/core/optimizers/schedulers.py rename to src/flash/core/optimizers/schedulers.py diff --git a/flash/core/registry.py b/src/flash/core/registry.py similarity index 100% rename from flash/core/registry.py rename to src/flash/core/registry.py diff --git a/flash/core/regression.py b/src/flash/core/regression.py similarity index 100% rename from flash/core/regression.py rename to src/flash/core/regression.py diff --git a/flash/core/serve/__init__.py b/src/flash/core/serve/__init__.py similarity index 100% rename from flash/core/serve/__init__.py rename to src/flash/core/serve/__init__.py diff --git a/flash/core/serve/_compat/__init__.py b/src/flash/core/serve/_compat/__init__.py similarity index 100% rename from flash/core/serve/_compat/__init__.py rename to src/flash/core/serve/_compat/__init__.py diff --git a/flash/core/serve/_compat/cached_property.py b/src/flash/core/serve/_compat/cached_property.py similarity index 100% rename from flash/core/serve/_compat/cached_property.py rename to src/flash/core/serve/_compat/cached_property.py diff --git a/flash/core/serve/component.py b/src/flash/core/serve/component.py similarity index 100% rename from flash/core/serve/component.py rename to src/flash/core/serve/component.py diff --git a/flash/core/serve/composition.py b/src/flash/core/serve/composition.py similarity index 100% rename from flash/core/serve/composition.py rename to src/flash/core/serve/composition.py diff --git a/flash/core/serve/core.py b/src/flash/core/serve/core.py similarity index 100% rename from flash/core/serve/core.py rename to src/flash/core/serve/core.py diff --git a/flash/core/serve/dag/NOTICE b/src/flash/core/serve/dag/NOTICE similarity index 100% rename from flash/core/serve/dag/NOTICE rename to src/flash/core/serve/dag/NOTICE diff --git a/flash/core/serve/interfaces/__init__.py b/src/flash/core/serve/dag/__init__.py similarity index 100% rename from flash/core/serve/interfaces/__init__.py rename to src/flash/core/serve/dag/__init__.py diff --git a/flash/core/serve/dag/optimization.py b/src/flash/core/serve/dag/optimization.py similarity index 100% rename from flash/core/serve/dag/optimization.py rename to src/flash/core/serve/dag/optimization.py diff --git a/flash/core/serve/dag/order.py b/src/flash/core/serve/dag/order.py similarity index 100% rename from flash/core/serve/dag/order.py rename to src/flash/core/serve/dag/order.py diff --git a/flash/core/serve/dag/rewrite.py b/src/flash/core/serve/dag/rewrite.py similarity index 100% rename from flash/core/serve/dag/rewrite.py rename to src/flash/core/serve/dag/rewrite.py diff --git a/flash/core/serve/dag/task.py b/src/flash/core/serve/dag/task.py similarity index 100% rename from flash/core/serve/dag/task.py rename to src/flash/core/serve/dag/task.py diff --git a/flash/core/serve/dag/utils.py b/src/flash/core/serve/dag/utils.py similarity index 100% rename from flash/core/serve/dag/utils.py rename to src/flash/core/serve/dag/utils.py diff --git a/flash/core/serve/dag/utils_test.py b/src/flash/core/serve/dag/utils_test.py similarity index 100% rename from flash/core/serve/dag/utils_test.py rename to src/flash/core/serve/dag/utils_test.py diff --git a/flash/core/serve/dag/visualize.py b/src/flash/core/serve/dag/visualize.py similarity index 100% rename from flash/core/serve/dag/visualize.py rename to src/flash/core/serve/dag/visualize.py diff --git a/flash/core/serve/decorators.py b/src/flash/core/serve/decorators.py similarity index 100% rename from flash/core/serve/decorators.py rename to src/flash/core/serve/decorators.py diff --git a/flash/core/serve/execution.py b/src/flash/core/serve/execution.py similarity index 100% rename from flash/core/serve/execution.py rename to src/flash/core/serve/execution.py diff --git a/flash/core/serve/flash_components.py b/src/flash/core/serve/flash_components.py similarity index 100% rename from flash/core/serve/flash_components.py rename to src/flash/core/serve/flash_components.py diff --git a/flash/core/serve/interfaces/templates/__init__.py b/src/flash/core/serve/interfaces/__init__.py similarity index 100% rename from flash/core/serve/interfaces/templates/__init__.py rename to src/flash/core/serve/interfaces/__init__.py diff --git a/flash/core/serve/interfaces/http.py b/src/flash/core/serve/interfaces/http.py similarity index 100% rename from flash/core/serve/interfaces/http.py rename to src/flash/core/serve/interfaces/http.py diff --git a/flash/core/serve/interfaces/models.py b/src/flash/core/serve/interfaces/models.py similarity index 100% rename from flash/core/serve/interfaces/models.py rename to src/flash/core/serve/interfaces/models.py diff --git a/flash/core/utilities/__init__.py b/src/flash/core/serve/interfaces/templates/__init__.py similarity index 100% rename from flash/core/utilities/__init__.py rename to src/flash/core/serve/interfaces/templates/__init__.py diff --git a/flash/core/serve/interfaces/templates/dag.html b/src/flash/core/serve/interfaces/templates/dag.html similarity index 100% rename from flash/core/serve/interfaces/templates/dag.html rename to src/flash/core/serve/interfaces/templates/dag.html diff --git a/flash/core/serve/server.py b/src/flash/core/serve/server.py similarity index 100% rename from flash/core/serve/server.py rename to src/flash/core/serve/server.py diff --git a/flash/core/serve/types/__init__.py b/src/flash/core/serve/types/__init__.py similarity index 100% rename from flash/core/serve/types/__init__.py rename to src/flash/core/serve/types/__init__.py diff --git a/flash/core/serve/types/base.py b/src/flash/core/serve/types/base.py similarity index 100% rename from flash/core/serve/types/base.py rename to src/flash/core/serve/types/base.py diff --git a/flash/core/serve/types/bbox.py b/src/flash/core/serve/types/bbox.py similarity index 100% rename from flash/core/serve/types/bbox.py rename to src/flash/core/serve/types/bbox.py diff --git a/flash/core/serve/types/image.py b/src/flash/core/serve/types/image.py similarity index 100% rename from flash/core/serve/types/image.py rename to src/flash/core/serve/types/image.py diff --git a/flash/core/serve/types/label.py b/src/flash/core/serve/types/label.py similarity index 100% rename from flash/core/serve/types/label.py rename to src/flash/core/serve/types/label.py diff --git a/flash/core/serve/types/number.py b/src/flash/core/serve/types/number.py similarity index 100% rename from flash/core/serve/types/number.py rename to src/flash/core/serve/types/number.py diff --git a/flash/core/serve/types/repeated.py b/src/flash/core/serve/types/repeated.py similarity index 100% rename from flash/core/serve/types/repeated.py rename to src/flash/core/serve/types/repeated.py diff --git a/flash/core/serve/types/table.py b/src/flash/core/serve/types/table.py similarity index 100% rename from flash/core/serve/types/table.py rename to src/flash/core/serve/types/table.py diff --git a/flash/core/serve/types/text.py b/src/flash/core/serve/types/text.py similarity index 100% rename from flash/core/serve/types/text.py rename to src/flash/core/serve/types/text.py diff --git a/flash/core/serve/utils.py b/src/flash/core/serve/utils.py similarity index 100% rename from flash/core/serve/utils.py rename to src/flash/core/serve/utils.py diff --git a/flash/core/trainer.py b/src/flash/core/trainer.py similarity index 100% rename from flash/core/trainer.py rename to src/flash/core/trainer.py diff --git a/flash/image/classification/integrations/__init__.py b/src/flash/core/utilities/__init__.py similarity index 100% rename from flash/image/classification/integrations/__init__.py rename to src/flash/core/utilities/__init__.py diff --git a/flash/core/utilities/apply_func.py b/src/flash/core/utilities/apply_func.py similarity index 100% rename from flash/core/utilities/apply_func.py rename to src/flash/core/utilities/apply_func.py diff --git a/flash/core/utilities/compatibility.py b/src/flash/core/utilities/compatibility.py similarity index 100% rename from flash/core/utilities/compatibility.py rename to src/flash/core/utilities/compatibility.py diff --git a/flash/core/utilities/embedder.py b/src/flash/core/utilities/embedder.py similarity index 100% rename from flash/core/utilities/embedder.py rename to src/flash/core/utilities/embedder.py diff --git a/flash/core/utilities/flash_cli.py b/src/flash/core/utilities/flash_cli.py similarity index 100% rename from flash/core/utilities/flash_cli.py rename to src/flash/core/utilities/flash_cli.py diff --git a/flash/core/utilities/imports.py b/src/flash/core/utilities/imports.py similarity index 100% rename from flash/core/utilities/imports.py rename to src/flash/core/utilities/imports.py diff --git a/flash/core/utilities/isinstance.py b/src/flash/core/utilities/isinstance.py similarity index 100% rename from flash/core/utilities/isinstance.py rename to src/flash/core/utilities/isinstance.py diff --git a/flash/core/utilities/lightning_cli.py b/src/flash/core/utilities/lightning_cli.py similarity index 100% rename from flash/core/utilities/lightning_cli.py rename to src/flash/core/utilities/lightning_cli.py diff --git a/flash/core/utilities/providers.py b/src/flash/core/utilities/providers.py similarity index 100% rename from flash/core/utilities/providers.py rename to src/flash/core/utilities/providers.py diff --git a/flash/core/utilities/stability.py b/src/flash/core/utilities/stability.py similarity index 100% rename from flash/core/utilities/stability.py rename to src/flash/core/utilities/stability.py diff --git a/flash/core/utilities/stages.py b/src/flash/core/utilities/stages.py similarity index 100% rename from flash/core/utilities/stages.py rename to src/flash/core/utilities/stages.py diff --git a/flash/core/utilities/types.py b/src/flash/core/utilities/types.py similarity index 100% rename from flash/core/utilities/types.py rename to src/flash/core/utilities/types.py diff --git a/flash/core/utilities/url_error.py b/src/flash/core/utilities/url_error.py similarity index 100% rename from flash/core/utilities/url_error.py rename to src/flash/core/utilities/url_error.py diff --git a/flash/graph/__init__.py b/src/flash/graph/__init__.py similarity index 100% rename from flash/graph/__init__.py rename to src/flash/graph/__init__.py diff --git a/flash/graph/backbones.py b/src/flash/graph/backbones.py similarity index 100% rename from flash/graph/backbones.py rename to src/flash/graph/backbones.py diff --git a/flash/graph/classification/__init__.py b/src/flash/graph/classification/__init__.py similarity index 100% rename from flash/graph/classification/__init__.py rename to src/flash/graph/classification/__init__.py diff --git a/flash/graph/classification/cli.py b/src/flash/graph/classification/cli.py similarity index 100% rename from flash/graph/classification/cli.py rename to src/flash/graph/classification/cli.py diff --git a/flash/graph/classification/data.py b/src/flash/graph/classification/data.py similarity index 100% rename from flash/graph/classification/data.py rename to src/flash/graph/classification/data.py diff --git a/flash/graph/classification/input.py b/src/flash/graph/classification/input.py similarity index 100% rename from flash/graph/classification/input.py rename to src/flash/graph/classification/input.py diff --git a/flash/graph/classification/input_transform.py b/src/flash/graph/classification/input_transform.py similarity index 100% rename from flash/graph/classification/input_transform.py rename to src/flash/graph/classification/input_transform.py diff --git a/flash/graph/classification/model.py b/src/flash/graph/classification/model.py similarity index 100% rename from flash/graph/classification/model.py rename to src/flash/graph/classification/model.py diff --git a/flash/graph/collate.py b/src/flash/graph/collate.py similarity index 100% rename from flash/graph/collate.py rename to src/flash/graph/collate.py diff --git a/flash/graph/embedding/__init__.py b/src/flash/graph/embedding/__init__.py similarity index 100% rename from flash/graph/embedding/__init__.py rename to src/flash/graph/embedding/__init__.py diff --git a/flash/graph/embedding/model.py b/src/flash/graph/embedding/model.py similarity index 100% rename from flash/graph/embedding/model.py rename to src/flash/graph/embedding/model.py diff --git a/flash/image/__init__.py b/src/flash/image/__init__.py similarity index 100% rename from flash/image/__init__.py rename to src/flash/image/__init__.py diff --git a/flash/image/classification/__init__.py b/src/flash/image/classification/__init__.py similarity index 100% rename from flash/image/classification/__init__.py rename to src/flash/image/classification/__init__.py diff --git a/flash/image/classification/adapters.py b/src/flash/image/classification/adapters.py similarity index 100% rename from flash/image/classification/adapters.py rename to src/flash/image/classification/adapters.py diff --git a/flash/image/classification/backbones/__init__.py b/src/flash/image/classification/backbones/__init__.py similarity index 100% rename from flash/image/classification/backbones/__init__.py rename to src/flash/image/classification/backbones/__init__.py diff --git a/flash/image/classification/backbones/clip.py b/src/flash/image/classification/backbones/clip.py similarity index 100% rename from flash/image/classification/backbones/clip.py rename to src/flash/image/classification/backbones/clip.py diff --git a/flash/image/classification/backbones/resnet.py b/src/flash/image/classification/backbones/resnet.py similarity index 100% rename from flash/image/classification/backbones/resnet.py rename to src/flash/image/classification/backbones/resnet.py diff --git a/flash/image/classification/backbones/timm.py b/src/flash/image/classification/backbones/timm.py similarity index 100% rename from flash/image/classification/backbones/timm.py rename to src/flash/image/classification/backbones/timm.py diff --git a/flash/image/classification/backbones/torchvision.py b/src/flash/image/classification/backbones/torchvision.py similarity index 100% rename from flash/image/classification/backbones/torchvision.py rename to src/flash/image/classification/backbones/torchvision.py diff --git a/flash/image/classification/backbones/transformers.py b/src/flash/image/classification/backbones/transformers.py similarity index 100% rename from flash/image/classification/backbones/transformers.py rename to src/flash/image/classification/backbones/transformers.py diff --git a/flash/image/classification/cli.py b/src/flash/image/classification/cli.py similarity index 100% rename from flash/image/classification/cli.py rename to src/flash/image/classification/cli.py diff --git a/flash/image/classification/data.py b/src/flash/image/classification/data.py similarity index 100% rename from flash/image/classification/data.py rename to src/flash/image/classification/data.py diff --git a/flash/image/classification/heads.py b/src/flash/image/classification/heads.py similarity index 100% rename from flash/image/classification/heads.py rename to src/flash/image/classification/heads.py diff --git a/flash/image/classification/input.py b/src/flash/image/classification/input.py similarity index 100% rename from flash/image/classification/input.py rename to src/flash/image/classification/input.py diff --git a/flash/image/classification/input_transform.py b/src/flash/image/classification/input_transform.py similarity index 100% rename from flash/image/classification/input_transform.py rename to src/flash/image/classification/input_transform.py diff --git a/flash/image/embedding/vissl/__init__.py b/src/flash/image/classification/integrations/__init__.py similarity index 100% rename from flash/image/embedding/vissl/__init__.py rename to src/flash/image/classification/integrations/__init__.py diff --git a/flash/image/classification/integrations/baal/__init__.py b/src/flash/image/classification/integrations/baal/__init__.py similarity index 100% rename from flash/image/classification/integrations/baal/__init__.py rename to src/flash/image/classification/integrations/baal/__init__.py diff --git a/flash/image/classification/integrations/baal/data.py b/src/flash/image/classification/integrations/baal/data.py similarity index 100% rename from flash/image/classification/integrations/baal/data.py rename to src/flash/image/classification/integrations/baal/data.py diff --git a/flash/image/classification/integrations/baal/dropout.py b/src/flash/image/classification/integrations/baal/dropout.py similarity index 100% rename from flash/image/classification/integrations/baal/dropout.py rename to src/flash/image/classification/integrations/baal/dropout.py diff --git a/flash/image/classification/integrations/baal/loop.py b/src/flash/image/classification/integrations/baal/loop.py similarity index 100% rename from flash/image/classification/integrations/baal/loop.py rename to src/flash/image/classification/integrations/baal/loop.py diff --git a/flash/image/classification/integrations/learn2learn.py b/src/flash/image/classification/integrations/learn2learn.py similarity index 100% rename from flash/image/classification/integrations/learn2learn.py rename to src/flash/image/classification/integrations/learn2learn.py diff --git a/flash/image/classification/model.py b/src/flash/image/classification/model.py similarity index 100% rename from flash/image/classification/model.py rename to src/flash/image/classification/model.py diff --git a/flash/image/data.py b/src/flash/image/data.py similarity index 100% rename from flash/image/data.py rename to src/flash/image/data.py diff --git a/flash/image/detection/__init__.py b/src/flash/image/detection/__init__.py similarity index 100% rename from flash/image/detection/__init__.py rename to src/flash/image/detection/__init__.py diff --git a/flash/image/detection/backbones.py b/src/flash/image/detection/backbones.py similarity index 100% rename from flash/image/detection/backbones.py rename to src/flash/image/detection/backbones.py diff --git a/flash/image/detection/cli.py b/src/flash/image/detection/cli.py similarity index 100% rename from flash/image/detection/cli.py rename to src/flash/image/detection/cli.py diff --git a/flash/image/detection/data.py b/src/flash/image/detection/data.py similarity index 100% rename from flash/image/detection/data.py rename to src/flash/image/detection/data.py diff --git a/flash/image/detection/input.py b/src/flash/image/detection/input.py similarity index 100% rename from flash/image/detection/input.py rename to src/flash/image/detection/input.py diff --git a/flash/image/detection/model.py b/src/flash/image/detection/model.py similarity index 100% rename from flash/image/detection/model.py rename to src/flash/image/detection/model.py diff --git a/flash/image/detection/output.py b/src/flash/image/detection/output.py similarity index 100% rename from flash/image/detection/output.py rename to src/flash/image/detection/output.py diff --git a/flash/image/embedding/__init__.py b/src/flash/image/embedding/__init__.py similarity index 100% rename from flash/image/embedding/__init__.py rename to src/flash/image/embedding/__init__.py diff --git a/flash/image/embedding/heads/__init__.py b/src/flash/image/embedding/heads/__init__.py similarity index 100% rename from flash/image/embedding/heads/__init__.py rename to src/flash/image/embedding/heads/__init__.py diff --git a/flash/image/embedding/heads/vissl_heads.py b/src/flash/image/embedding/heads/vissl_heads.py similarity index 100% rename from flash/image/embedding/heads/vissl_heads.py rename to src/flash/image/embedding/heads/vissl_heads.py diff --git a/flash/image/embedding/losses/__init__.py b/src/flash/image/embedding/losses/__init__.py similarity index 100% rename from flash/image/embedding/losses/__init__.py rename to src/flash/image/embedding/losses/__init__.py diff --git a/flash/image/embedding/losses/vissl_losses.py b/src/flash/image/embedding/losses/vissl_losses.py similarity index 100% rename from flash/image/embedding/losses/vissl_losses.py rename to src/flash/image/embedding/losses/vissl_losses.py diff --git a/flash/image/embedding/model.py b/src/flash/image/embedding/model.py similarity index 100% rename from flash/image/embedding/model.py rename to src/flash/image/embedding/model.py diff --git a/flash/image/embedding/strategies/__init__.py b/src/flash/image/embedding/strategies/__init__.py similarity index 100% rename from flash/image/embedding/strategies/__init__.py rename to src/flash/image/embedding/strategies/__init__.py diff --git a/flash/image/embedding/strategies/default.py b/src/flash/image/embedding/strategies/default.py similarity index 100% rename from flash/image/embedding/strategies/default.py rename to src/flash/image/embedding/strategies/default.py diff --git a/flash/image/embedding/strategies/vissl_strategies.py b/src/flash/image/embedding/strategies/vissl_strategies.py similarity index 100% rename from flash/image/embedding/strategies/vissl_strategies.py rename to src/flash/image/embedding/strategies/vissl_strategies.py diff --git a/flash/image/embedding/transforms/__init__.py b/src/flash/image/embedding/transforms/__init__.py similarity index 100% rename from flash/image/embedding/transforms/__init__.py rename to src/flash/image/embedding/transforms/__init__.py diff --git a/flash/image/embedding/transforms/vissl_transforms.py b/src/flash/image/embedding/transforms/vissl_transforms.py similarity index 100% rename from flash/image/embedding/transforms/vissl_transforms.py rename to src/flash/image/embedding/transforms/vissl_transforms.py diff --git a/flash/pointcloud/detection/open3d_ml/__init__.py b/src/flash/image/embedding/vissl/__init__.py similarity index 100% rename from flash/pointcloud/detection/open3d_ml/__init__.py rename to src/flash/image/embedding/vissl/__init__.py diff --git a/flash/image/embedding/vissl/adapter.py b/src/flash/image/embedding/vissl/adapter.py similarity index 100% rename from flash/image/embedding/vissl/adapter.py rename to src/flash/image/embedding/vissl/adapter.py diff --git a/flash/image/embedding/vissl/hooks.py b/src/flash/image/embedding/vissl/hooks.py similarity index 100% rename from flash/image/embedding/vissl/hooks.py rename to src/flash/image/embedding/vissl/hooks.py diff --git a/flash/image/embedding/vissl/transforms/__init__.py b/src/flash/image/embedding/vissl/transforms/__init__.py similarity index 100% rename from flash/image/embedding/vissl/transforms/__init__.py rename to src/flash/image/embedding/vissl/transforms/__init__.py diff --git a/flash/image/embedding/vissl/transforms/multicrop.py b/src/flash/image/embedding/vissl/transforms/multicrop.py similarity index 100% rename from flash/image/embedding/vissl/transforms/multicrop.py rename to src/flash/image/embedding/vissl/transforms/multicrop.py diff --git a/flash/image/embedding/vissl/transforms/utilities.py b/src/flash/image/embedding/vissl/transforms/utilities.py similarity index 100% rename from flash/image/embedding/vissl/transforms/utilities.py rename to src/flash/image/embedding/vissl/transforms/utilities.py diff --git a/flash/image/face_detection/__init__.py b/src/flash/image/face_detection/__init__.py similarity index 100% rename from flash/image/face_detection/__init__.py rename to src/flash/image/face_detection/__init__.py diff --git a/flash/image/face_detection/backbones/__init__.py b/src/flash/image/face_detection/backbones/__init__.py similarity index 100% rename from flash/image/face_detection/backbones/__init__.py rename to src/flash/image/face_detection/backbones/__init__.py diff --git a/flash/image/face_detection/backbones/fastface_backbones.py b/src/flash/image/face_detection/backbones/fastface_backbones.py similarity index 100% rename from flash/image/face_detection/backbones/fastface_backbones.py rename to src/flash/image/face_detection/backbones/fastface_backbones.py diff --git a/flash/image/face_detection/cli.py b/src/flash/image/face_detection/cli.py similarity index 100% rename from flash/image/face_detection/cli.py rename to src/flash/image/face_detection/cli.py diff --git a/flash/image/face_detection/data.py b/src/flash/image/face_detection/data.py similarity index 100% rename from flash/image/face_detection/data.py rename to src/flash/image/face_detection/data.py diff --git a/flash/image/face_detection/input.py b/src/flash/image/face_detection/input.py similarity index 100% rename from flash/image/face_detection/input.py rename to src/flash/image/face_detection/input.py diff --git a/flash/image/face_detection/input_transform.py b/src/flash/image/face_detection/input_transform.py similarity index 100% rename from flash/image/face_detection/input_transform.py rename to src/flash/image/face_detection/input_transform.py diff --git a/flash/image/face_detection/model.py b/src/flash/image/face_detection/model.py similarity index 100% rename from flash/image/face_detection/model.py rename to src/flash/image/face_detection/model.py diff --git a/flash/image/face_detection/output_transform.py b/src/flash/image/face_detection/output_transform.py similarity index 100% rename from flash/image/face_detection/output_transform.py rename to src/flash/image/face_detection/output_transform.py diff --git a/flash/image/instance_segmentation/__init__.py b/src/flash/image/instance_segmentation/__init__.py similarity index 100% rename from flash/image/instance_segmentation/__init__.py rename to src/flash/image/instance_segmentation/__init__.py diff --git a/flash/image/instance_segmentation/backbones.py b/src/flash/image/instance_segmentation/backbones.py similarity index 100% rename from flash/image/instance_segmentation/backbones.py rename to src/flash/image/instance_segmentation/backbones.py diff --git a/flash/image/instance_segmentation/cli.py b/src/flash/image/instance_segmentation/cli.py similarity index 100% rename from flash/image/instance_segmentation/cli.py rename to src/flash/image/instance_segmentation/cli.py diff --git a/flash/image/instance_segmentation/data.py b/src/flash/image/instance_segmentation/data.py similarity index 100% rename from flash/image/instance_segmentation/data.py rename to src/flash/image/instance_segmentation/data.py diff --git a/flash/image/instance_segmentation/model.py b/src/flash/image/instance_segmentation/model.py similarity index 100% rename from flash/image/instance_segmentation/model.py rename to src/flash/image/instance_segmentation/model.py diff --git a/flash/image/keypoint_detection/__init__.py b/src/flash/image/keypoint_detection/__init__.py similarity index 100% rename from flash/image/keypoint_detection/__init__.py rename to src/flash/image/keypoint_detection/__init__.py diff --git a/flash/image/keypoint_detection/backbones.py b/src/flash/image/keypoint_detection/backbones.py similarity index 100% rename from flash/image/keypoint_detection/backbones.py rename to src/flash/image/keypoint_detection/backbones.py diff --git a/flash/image/keypoint_detection/cli.py b/src/flash/image/keypoint_detection/cli.py similarity index 100% rename from flash/image/keypoint_detection/cli.py rename to src/flash/image/keypoint_detection/cli.py diff --git a/flash/image/keypoint_detection/data.py b/src/flash/image/keypoint_detection/data.py similarity index 100% rename from flash/image/keypoint_detection/data.py rename to src/flash/image/keypoint_detection/data.py diff --git a/flash/image/keypoint_detection/input_transform.py b/src/flash/image/keypoint_detection/input_transform.py similarity index 100% rename from flash/image/keypoint_detection/input_transform.py rename to src/flash/image/keypoint_detection/input_transform.py diff --git a/flash/image/keypoint_detection/model.py b/src/flash/image/keypoint_detection/model.py similarity index 100% rename from flash/image/keypoint_detection/model.py rename to src/flash/image/keypoint_detection/model.py diff --git a/flash/image/segmentation/__init__.py b/src/flash/image/segmentation/__init__.py similarity index 100% rename from flash/image/segmentation/__init__.py rename to src/flash/image/segmentation/__init__.py diff --git a/flash/image/segmentation/backbones.py b/src/flash/image/segmentation/backbones.py similarity index 100% rename from flash/image/segmentation/backbones.py rename to src/flash/image/segmentation/backbones.py diff --git a/flash/image/segmentation/cli.py b/src/flash/image/segmentation/cli.py similarity index 100% rename from flash/image/segmentation/cli.py rename to src/flash/image/segmentation/cli.py diff --git a/flash/image/segmentation/data.py b/src/flash/image/segmentation/data.py similarity index 100% rename from flash/image/segmentation/data.py rename to src/flash/image/segmentation/data.py diff --git a/flash/image/segmentation/heads.py b/src/flash/image/segmentation/heads.py similarity index 100% rename from flash/image/segmentation/heads.py rename to src/flash/image/segmentation/heads.py diff --git a/flash/image/segmentation/input.py b/src/flash/image/segmentation/input.py similarity index 100% rename from flash/image/segmentation/input.py rename to src/flash/image/segmentation/input.py diff --git a/flash/image/segmentation/input_transform.py b/src/flash/image/segmentation/input_transform.py similarity index 100% rename from flash/image/segmentation/input_transform.py rename to src/flash/image/segmentation/input_transform.py diff --git a/flash/image/segmentation/model.py b/src/flash/image/segmentation/model.py similarity index 100% rename from flash/image/segmentation/model.py rename to src/flash/image/segmentation/model.py diff --git a/flash/image/segmentation/output.py b/src/flash/image/segmentation/output.py similarity index 100% rename from flash/image/segmentation/output.py rename to src/flash/image/segmentation/output.py diff --git a/flash/image/segmentation/viz.py b/src/flash/image/segmentation/viz.py similarity index 100% rename from flash/image/segmentation/viz.py rename to src/flash/image/segmentation/viz.py diff --git a/flash/image/style_transfer/__init__.py b/src/flash/image/style_transfer/__init__.py similarity index 100% rename from flash/image/style_transfer/__init__.py rename to src/flash/image/style_transfer/__init__.py diff --git a/flash/image/style_transfer/backbones.py b/src/flash/image/style_transfer/backbones.py similarity index 100% rename from flash/image/style_transfer/backbones.py rename to src/flash/image/style_transfer/backbones.py diff --git a/flash/image/style_transfer/cli.py b/src/flash/image/style_transfer/cli.py similarity index 100% rename from flash/image/style_transfer/cli.py rename to src/flash/image/style_transfer/cli.py diff --git a/flash/image/style_transfer/data.py b/src/flash/image/style_transfer/data.py similarity index 100% rename from flash/image/style_transfer/data.py rename to src/flash/image/style_transfer/data.py diff --git a/flash/image/style_transfer/input_transform.py b/src/flash/image/style_transfer/input_transform.py similarity index 100% rename from flash/image/style_transfer/input_transform.py rename to src/flash/image/style_transfer/input_transform.py diff --git a/flash/image/style_transfer/model.py b/src/flash/image/style_transfer/model.py similarity index 100% rename from flash/image/style_transfer/model.py rename to src/flash/image/style_transfer/model.py diff --git a/flash/image/style_transfer/utils.py b/src/flash/image/style_transfer/utils.py similarity index 100% rename from flash/image/style_transfer/utils.py rename to src/flash/image/style_transfer/utils.py diff --git a/flash/pointcloud/__init__.py b/src/flash/pointcloud/__init__.py similarity index 100% rename from flash/pointcloud/__init__.py rename to src/flash/pointcloud/__init__.py diff --git a/flash/pointcloud/detection/__init__.py b/src/flash/pointcloud/detection/__init__.py similarity index 100% rename from flash/pointcloud/detection/__init__.py rename to src/flash/pointcloud/detection/__init__.py diff --git a/flash/pointcloud/detection/backbones.py b/src/flash/pointcloud/detection/backbones.py similarity index 100% rename from flash/pointcloud/detection/backbones.py rename to src/flash/pointcloud/detection/backbones.py diff --git a/flash/pointcloud/detection/cli.py b/src/flash/pointcloud/detection/cli.py similarity index 100% rename from flash/pointcloud/detection/cli.py rename to src/flash/pointcloud/detection/cli.py diff --git a/flash/pointcloud/detection/data.py b/src/flash/pointcloud/detection/data.py similarity index 100% rename from flash/pointcloud/detection/data.py rename to src/flash/pointcloud/detection/data.py diff --git a/flash/pointcloud/detection/datasets.py b/src/flash/pointcloud/detection/datasets.py similarity index 100% rename from flash/pointcloud/detection/datasets.py rename to src/flash/pointcloud/detection/datasets.py diff --git a/flash/pointcloud/detection/input.py b/src/flash/pointcloud/detection/input.py similarity index 100% rename from flash/pointcloud/detection/input.py rename to src/flash/pointcloud/detection/input.py diff --git a/flash/pointcloud/detection/model.py b/src/flash/pointcloud/detection/model.py similarity index 100% rename from flash/pointcloud/detection/model.py rename to src/flash/pointcloud/detection/model.py diff --git a/flash/pointcloud/segmentation/open3d_ml/__init__.py b/src/flash/pointcloud/detection/open3d_ml/__init__.py similarity index 100% rename from flash/pointcloud/segmentation/open3d_ml/__init__.py rename to src/flash/pointcloud/detection/open3d_ml/__init__.py diff --git a/flash/pointcloud/detection/open3d_ml/app.py b/src/flash/pointcloud/detection/open3d_ml/app.py similarity index 100% rename from flash/pointcloud/detection/open3d_ml/app.py rename to src/flash/pointcloud/detection/open3d_ml/app.py diff --git a/flash/pointcloud/detection/open3d_ml/backbones.py b/src/flash/pointcloud/detection/open3d_ml/backbones.py similarity index 100% rename from flash/pointcloud/detection/open3d_ml/backbones.py rename to src/flash/pointcloud/detection/open3d_ml/backbones.py diff --git a/flash/pointcloud/detection/open3d_ml/input.py b/src/flash/pointcloud/detection/open3d_ml/input.py similarity index 100% rename from flash/pointcloud/detection/open3d_ml/input.py rename to src/flash/pointcloud/detection/open3d_ml/input.py diff --git a/flash/pointcloud/segmentation/__init__.py b/src/flash/pointcloud/segmentation/__init__.py similarity index 100% rename from flash/pointcloud/segmentation/__init__.py rename to src/flash/pointcloud/segmentation/__init__.py diff --git a/flash/pointcloud/segmentation/backbones.py b/src/flash/pointcloud/segmentation/backbones.py similarity index 100% rename from flash/pointcloud/segmentation/backbones.py rename to src/flash/pointcloud/segmentation/backbones.py diff --git a/flash/pointcloud/segmentation/cli.py b/src/flash/pointcloud/segmentation/cli.py similarity index 100% rename from flash/pointcloud/segmentation/cli.py rename to src/flash/pointcloud/segmentation/cli.py diff --git a/flash/pointcloud/segmentation/data.py b/src/flash/pointcloud/segmentation/data.py similarity index 100% rename from flash/pointcloud/segmentation/data.py rename to src/flash/pointcloud/segmentation/data.py diff --git a/flash/pointcloud/segmentation/datasets.py b/src/flash/pointcloud/segmentation/datasets.py similarity index 100% rename from flash/pointcloud/segmentation/datasets.py rename to src/flash/pointcloud/segmentation/datasets.py diff --git a/flash/pointcloud/segmentation/input.py b/src/flash/pointcloud/segmentation/input.py similarity index 100% rename from flash/pointcloud/segmentation/input.py rename to src/flash/pointcloud/segmentation/input.py diff --git a/flash/pointcloud/segmentation/model.py b/src/flash/pointcloud/segmentation/model.py similarity index 100% rename from flash/pointcloud/segmentation/model.py rename to src/flash/pointcloud/segmentation/model.py diff --git a/flash/video/classification/__init__.py b/src/flash/pointcloud/segmentation/open3d_ml/__init__.py similarity index 100% rename from flash/video/classification/__init__.py rename to src/flash/pointcloud/segmentation/open3d_ml/__init__.py diff --git a/flash/pointcloud/segmentation/open3d_ml/app.py b/src/flash/pointcloud/segmentation/open3d_ml/app.py similarity index 100% rename from flash/pointcloud/segmentation/open3d_ml/app.py rename to src/flash/pointcloud/segmentation/open3d_ml/app.py diff --git a/flash/pointcloud/segmentation/open3d_ml/backbones.py b/src/flash/pointcloud/segmentation/open3d_ml/backbones.py similarity index 100% rename from flash/pointcloud/segmentation/open3d_ml/backbones.py rename to src/flash/pointcloud/segmentation/open3d_ml/backbones.py diff --git a/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py b/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py similarity index 100% rename from flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py rename to src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py diff --git a/flash/tabular/__init__.py b/src/flash/tabular/__init__.py similarity index 100% rename from flash/tabular/__init__.py rename to src/flash/tabular/__init__.py diff --git a/flash/tabular/classification/__init__.py b/src/flash/tabular/classification/__init__.py similarity index 100% rename from flash/tabular/classification/__init__.py rename to src/flash/tabular/classification/__init__.py diff --git a/flash/tabular/classification/cli.py b/src/flash/tabular/classification/cli.py similarity index 100% rename from flash/tabular/classification/cli.py rename to src/flash/tabular/classification/cli.py diff --git a/flash/tabular/classification/data.py b/src/flash/tabular/classification/data.py similarity index 100% rename from flash/tabular/classification/data.py rename to src/flash/tabular/classification/data.py diff --git a/flash/tabular/classification/input.py b/src/flash/tabular/classification/input.py similarity index 100% rename from flash/tabular/classification/input.py rename to src/flash/tabular/classification/input.py diff --git a/flash/tabular/classification/model.py b/src/flash/tabular/classification/model.py similarity index 100% rename from flash/tabular/classification/model.py rename to src/flash/tabular/classification/model.py diff --git a/flash/tabular/classification/utils.py b/src/flash/tabular/classification/utils.py similarity index 100% rename from flash/tabular/classification/utils.py rename to src/flash/tabular/classification/utils.py diff --git a/flash/tabular/data.py b/src/flash/tabular/data.py similarity index 100% rename from flash/tabular/data.py rename to src/flash/tabular/data.py diff --git a/flash/tabular/forecasting/__init__.py b/src/flash/tabular/forecasting/__init__.py similarity index 100% rename from flash/tabular/forecasting/__init__.py rename to src/flash/tabular/forecasting/__init__.py diff --git a/flash/tabular/forecasting/cli.py b/src/flash/tabular/forecasting/cli.py similarity index 100% rename from flash/tabular/forecasting/cli.py rename to src/flash/tabular/forecasting/cli.py diff --git a/flash/tabular/forecasting/data.py b/src/flash/tabular/forecasting/data.py similarity index 100% rename from flash/tabular/forecasting/data.py rename to src/flash/tabular/forecasting/data.py diff --git a/flash/tabular/forecasting/input.py b/src/flash/tabular/forecasting/input.py similarity index 100% rename from flash/tabular/forecasting/input.py rename to src/flash/tabular/forecasting/input.py diff --git a/flash/tabular/forecasting/model.py b/src/flash/tabular/forecasting/model.py similarity index 100% rename from flash/tabular/forecasting/model.py rename to src/flash/tabular/forecasting/model.py diff --git a/flash/tabular/input.py b/src/flash/tabular/input.py similarity index 100% rename from flash/tabular/input.py rename to src/flash/tabular/input.py diff --git a/flash/tabular/regression/__init__.py b/src/flash/tabular/regression/__init__.py similarity index 100% rename from flash/tabular/regression/__init__.py rename to src/flash/tabular/regression/__init__.py diff --git a/flash/tabular/regression/cli.py b/src/flash/tabular/regression/cli.py similarity index 100% rename from flash/tabular/regression/cli.py rename to src/flash/tabular/regression/cli.py diff --git a/flash/tabular/regression/data.py b/src/flash/tabular/regression/data.py similarity index 100% rename from flash/tabular/regression/data.py rename to src/flash/tabular/regression/data.py diff --git a/flash/tabular/regression/input.py b/src/flash/tabular/regression/input.py similarity index 100% rename from flash/tabular/regression/input.py rename to src/flash/tabular/regression/input.py diff --git a/flash/tabular/regression/model.py b/src/flash/tabular/regression/model.py similarity index 100% rename from flash/tabular/regression/model.py rename to src/flash/tabular/regression/model.py diff --git a/flash/template/__init__.py b/src/flash/template/__init__.py similarity index 100% rename from flash/template/__init__.py rename to src/flash/template/__init__.py diff --git a/flash/template/classification/__init__.py b/src/flash/template/classification/__init__.py similarity index 100% rename from flash/template/classification/__init__.py rename to src/flash/template/classification/__init__.py diff --git a/flash/template/classification/backbones.py b/src/flash/template/classification/backbones.py similarity index 100% rename from flash/template/classification/backbones.py rename to src/flash/template/classification/backbones.py diff --git a/flash/template/classification/data.py b/src/flash/template/classification/data.py similarity index 100% rename from flash/template/classification/data.py rename to src/flash/template/classification/data.py diff --git a/flash/template/classification/model.py b/src/flash/template/classification/model.py similarity index 100% rename from flash/template/classification/model.py rename to src/flash/template/classification/model.py diff --git a/flash/text/__init__.py b/src/flash/text/__init__.py similarity index 100% rename from flash/text/__init__.py rename to src/flash/text/__init__.py diff --git a/flash/text/classification/__init__.py b/src/flash/text/classification/__init__.py similarity index 100% rename from flash/text/classification/__init__.py rename to src/flash/text/classification/__init__.py diff --git a/flash/text/classification/adapters.py b/src/flash/text/classification/adapters.py similarity index 100% rename from flash/text/classification/adapters.py rename to src/flash/text/classification/adapters.py diff --git a/flash/text/classification/backbones/__init__.py b/src/flash/text/classification/backbones/__init__.py similarity index 100% rename from flash/text/classification/backbones/__init__.py rename to src/flash/text/classification/backbones/__init__.py diff --git a/flash/text/classification/backbones/clip.py b/src/flash/text/classification/backbones/clip.py similarity index 100% rename from flash/text/classification/backbones/clip.py rename to src/flash/text/classification/backbones/clip.py diff --git a/flash/text/classification/backbones/huggingface.py b/src/flash/text/classification/backbones/huggingface.py similarity index 100% rename from flash/text/classification/backbones/huggingface.py rename to src/flash/text/classification/backbones/huggingface.py diff --git a/flash/text/classification/cli.py b/src/flash/text/classification/cli.py similarity index 100% rename from flash/text/classification/cli.py rename to src/flash/text/classification/cli.py diff --git a/flash/text/classification/collate.py b/src/flash/text/classification/collate.py similarity index 100% rename from flash/text/classification/collate.py rename to src/flash/text/classification/collate.py diff --git a/flash/text/classification/data.py b/src/flash/text/classification/data.py similarity index 100% rename from flash/text/classification/data.py rename to src/flash/text/classification/data.py diff --git a/flash/text/classification/input.py b/src/flash/text/classification/input.py similarity index 100% rename from flash/text/classification/input.py rename to src/flash/text/classification/input.py diff --git a/flash/text/classification/model.py b/src/flash/text/classification/model.py similarity index 100% rename from flash/text/classification/model.py rename to src/flash/text/classification/model.py diff --git a/flash/text/embedding/__init__.py b/src/flash/text/embedding/__init__.py similarity index 100% rename from flash/text/embedding/__init__.py rename to src/flash/text/embedding/__init__.py diff --git a/flash/text/embedding/backbones.py b/src/flash/text/embedding/backbones.py similarity index 100% rename from flash/text/embedding/backbones.py rename to src/flash/text/embedding/backbones.py diff --git a/flash/text/embedding/model.py b/src/flash/text/embedding/model.py similarity index 100% rename from flash/text/embedding/model.py rename to src/flash/text/embedding/model.py diff --git a/flash/text/input.py b/src/flash/text/input.py similarity index 100% rename from flash/text/input.py rename to src/flash/text/input.py diff --git a/flash/text/ort_callback.py b/src/flash/text/ort_callback.py similarity index 100% rename from flash/text/ort_callback.py rename to src/flash/text/ort_callback.py diff --git a/flash/text/question_answering/__init__.py b/src/flash/text/question_answering/__init__.py similarity index 100% rename from flash/text/question_answering/__init__.py rename to src/flash/text/question_answering/__init__.py diff --git a/flash/text/question_answering/cli.py b/src/flash/text/question_answering/cli.py similarity index 100% rename from flash/text/question_answering/cli.py rename to src/flash/text/question_answering/cli.py diff --git a/flash/text/question_answering/collate.py b/src/flash/text/question_answering/collate.py similarity index 100% rename from flash/text/question_answering/collate.py rename to src/flash/text/question_answering/collate.py diff --git a/flash/text/question_answering/data.py b/src/flash/text/question_answering/data.py similarity index 100% rename from flash/text/question_answering/data.py rename to src/flash/text/question_answering/data.py diff --git a/flash/text/question_answering/input.py b/src/flash/text/question_answering/input.py similarity index 100% rename from flash/text/question_answering/input.py rename to src/flash/text/question_answering/input.py diff --git a/flash/text/question_answering/model.py b/src/flash/text/question_answering/model.py similarity index 100% rename from flash/text/question_answering/model.py rename to src/flash/text/question_answering/model.py diff --git a/flash/text/question_answering/output_transform.py b/src/flash/text/question_answering/output_transform.py similarity index 100% rename from flash/text/question_answering/output_transform.py rename to src/flash/text/question_answering/output_transform.py diff --git a/flash/text/seq2seq/__init__.py b/src/flash/text/seq2seq/__init__.py similarity index 100% rename from flash/text/seq2seq/__init__.py rename to src/flash/text/seq2seq/__init__.py diff --git a/flash/text/seq2seq/core/__init__.py b/src/flash/text/seq2seq/core/__init__.py similarity index 100% rename from flash/text/seq2seq/core/__init__.py rename to src/flash/text/seq2seq/core/__init__.py diff --git a/flash/text/seq2seq/core/collate.py b/src/flash/text/seq2seq/core/collate.py similarity index 100% rename from flash/text/seq2seq/core/collate.py rename to src/flash/text/seq2seq/core/collate.py diff --git a/flash/text/seq2seq/core/input.py b/src/flash/text/seq2seq/core/input.py similarity index 100% rename from flash/text/seq2seq/core/input.py rename to src/flash/text/seq2seq/core/input.py diff --git a/flash/text/seq2seq/core/model.py b/src/flash/text/seq2seq/core/model.py similarity index 100% rename from flash/text/seq2seq/core/model.py rename to src/flash/text/seq2seq/core/model.py diff --git a/flash/text/seq2seq/summarization/__init__.py b/src/flash/text/seq2seq/summarization/__init__.py similarity index 100% rename from flash/text/seq2seq/summarization/__init__.py rename to src/flash/text/seq2seq/summarization/__init__.py diff --git a/flash/text/seq2seq/summarization/cli.py b/src/flash/text/seq2seq/summarization/cli.py similarity index 100% rename from flash/text/seq2seq/summarization/cli.py rename to src/flash/text/seq2seq/summarization/cli.py diff --git a/flash/text/seq2seq/summarization/data.py b/src/flash/text/seq2seq/summarization/data.py similarity index 100% rename from flash/text/seq2seq/summarization/data.py rename to src/flash/text/seq2seq/summarization/data.py diff --git a/flash/text/seq2seq/summarization/model.py b/src/flash/text/seq2seq/summarization/model.py similarity index 100% rename from flash/text/seq2seq/summarization/model.py rename to src/flash/text/seq2seq/summarization/model.py diff --git a/flash/text/seq2seq/translation/__init__.py b/src/flash/text/seq2seq/translation/__init__.py similarity index 100% rename from flash/text/seq2seq/translation/__init__.py rename to src/flash/text/seq2seq/translation/__init__.py diff --git a/flash/text/seq2seq/translation/cli.py b/src/flash/text/seq2seq/translation/cli.py similarity index 100% rename from flash/text/seq2seq/translation/cli.py rename to src/flash/text/seq2seq/translation/cli.py diff --git a/flash/text/seq2seq/translation/data.py b/src/flash/text/seq2seq/translation/data.py similarity index 100% rename from flash/text/seq2seq/translation/data.py rename to src/flash/text/seq2seq/translation/data.py diff --git a/flash/text/seq2seq/translation/model.py b/src/flash/text/seq2seq/translation/model.py similarity index 100% rename from flash/text/seq2seq/translation/model.py rename to src/flash/text/seq2seq/translation/model.py diff --git a/flash/video/__init__.py b/src/flash/video/__init__.py similarity index 100% rename from flash/video/__init__.py rename to src/flash/video/__init__.py diff --git a/flash_examples/serve/image_classification/__init__.py b/src/flash/video/classification/__init__.py similarity index 100% rename from flash_examples/serve/image_classification/__init__.py rename to src/flash/video/classification/__init__.py diff --git a/flash/video/classification/cli.py b/src/flash/video/classification/cli.py similarity index 100% rename from flash/video/classification/cli.py rename to src/flash/video/classification/cli.py diff --git a/flash/video/classification/data.py b/src/flash/video/classification/data.py similarity index 100% rename from flash/video/classification/data.py rename to src/flash/video/classification/data.py diff --git a/flash/video/classification/input.py b/src/flash/video/classification/input.py similarity index 100% rename from flash/video/classification/input.py rename to src/flash/video/classification/input.py diff --git a/flash/video/classification/input_transform.py b/src/flash/video/classification/input_transform.py similarity index 100% rename from flash/video/classification/input_transform.py rename to src/flash/video/classification/input_transform.py diff --git a/flash/video/classification/model.py b/src/flash/video/classification/model.py similarity index 100% rename from flash/video/classification/model.py rename to src/flash/video/classification/model.py diff --git a/flash/video/classification/utils.py b/src/flash/video/classification/utils.py similarity index 100% rename from flash/video/classification/utils.py rename to src/flash/video/classification/utils.py diff --git a/tests/core/serve/test_integration.py b/tests/core/serve/test_integration.py index 198ff5106a..51a79b71ae 100644 --- a/tests/core/serve/test_integration.py +++ b/tests/core/serve/test_integration.py @@ -3,7 +3,12 @@ import pytest from flash.core.serve import Composition, Endpoint -from flash.core.utilities.imports import _FASTAPI_AVAILABLE, _SERVE_TESTING +from flash.core.utilities.imports import _FASTAPI_AVAILABLE, _SERVE_AVAILABLE, _SERVE_TESTING + +if _SERVE_AVAILABLE: + from jinja2 import TemplateNotFound +else: + TemplateNotFound = ... if _FASTAPI_AVAILABLE: from fastapi.testclient import TestClient @@ -154,6 +159,7 @@ def test_serving_single_component_and_endpoint_no_composition(session_global_dat @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.xfail(TemplateNotFound, reason="jinja2.exceptions.TemplateNotFound: dag.html") # todo def test_serving_composed(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInference, SeatClassifier @@ -208,6 +214,7 @@ def test_serving_composed(session_global_datadir, lightning_squeezenet1_1_obj): @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.xfail(TemplateNotFound, reason="jinja2.exceptions.TemplateNotFound: dag.html") # todo def test_composed_does_not_eliminate_endpoint_serialization(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInference, SeatClassifier @@ -283,6 +290,7 @@ def test_composed_does_not_eliminate_endpoint_serialization(session_global_datad @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.xfail(TemplateNotFound, reason="jinja2.exceptions.TemplateNotFound: dag.html") # todo def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInference, SeatClassifier diff --git a/tests/examples/test_integrations.py b/tests/examples/test_integrations.py index 4588c1445b..035e1ec7ea 100644 --- a/tests/examples/test_integrations.py +++ b/tests/examples/test_integrations.py @@ -59,4 +59,4 @@ ], ) def test_integrations(tmpdir, folder, file): - run_test(str(root / "flash_examples" / "integrations" / folder / file)) + run_test(str(root / "examples" / "integrations" / folder / file)) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 8337fad2c1..e423c34b2c 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -167,4 +167,4 @@ ) @forked def test_example(tmpdir, file): - run_test(str(root / "flash_examples" / file)) + run_test(str(root / "examples" / file)) diff --git a/tests/image/style_transfer/test_model.py b/tests/image/style_transfer/test_model.py index 7ce7a8dd10..37b27580e3 100644 --- a/tests/image/style_transfer/test_model.py +++ b/tests/image/style_transfer/test_model.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any +from urllib.error import URLError import pytest import torch @@ -23,6 +24,7 @@ from tests.helpers.task_tester import TaskTester +@pytest.mark.xfail(URLError, reason="Connection timed out for download.pystiche.org") class TestStyleTransfer(TaskTester): task = StyleTransfer cli_command = "style_transfer" @@ -47,6 +49,7 @@ def example_train_sample(self): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.xfail(URLError, reason="Connection timed out for download.pystiche.org") def test_style_transfer_task(): model = StyleTransfer( backbone="vgg11", content_layer="relu1_2", content_weight=10, style_layers="relu1_2", style_weight=11 From 73e4ce911613b76521bf72099db3896c78a48828 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 24 Mar 2023 12:46:20 +0100 Subject: [PATCH 15/39] switch to pyproject & ruff (#1537) * pyproject * ruff * fixing * toml * warning --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 32 +++-- .../boston_prediction/inference_server.py | 2 +- examples/serve/generic/detection/inference.py | 2 +- .../visualizations/pointcloud_detection.py | 2 +- .../visualizations/pointcloud_segmentation.py | 2 +- pyproject.toml | 113 +++++++++++++++++- requirements/test.txt | 4 +- setup.cfg | 101 ---------------- src/flash/audio/classification/input.py | 4 +- src/flash/core/adapter.py | 2 +- src/flash/core/data/batch.py | 4 +- src/flash/core/data/data_module.py | 2 +- .../core/data/io/classification_input.py | 2 +- src/flash/core/data/io/input.py | 2 +- .../core/data/utilities/classification.py | 2 +- src/flash/core/data/utilities/paths.py | 2 +- src/flash/core/integrations/fiftyone/utils.py | 2 +- .../core/integrations/icevision/adapter.py | 2 +- src/flash/core/integrations/icevision/data.py | 2 +- .../core/integrations/icevision/transforms.py | 2 +- src/flash/core/model.py | 4 +- src/flash/core/optimizers/lamb.py | 1 - src/flash/core/optimizers/lars.py | 1 - src/flash/core/optimizers/lr_scheduler.py | 3 +- src/flash/core/optimizers/schedulers.py | 2 +- src/flash/core/serve/core.py | 2 +- src/flash/core/serve/dag/optimization.py | 1 - src/flash/core/serve/dag/order.py | 3 +- src/flash/core/serve/dag/task.py | 1 - src/flash/core/serve/decorators.py | 2 +- src/flash/core/serve/execution.py | 2 +- src/flash/core/serve/flash_components.py | 2 +- src/flash/core/serve/interfaces/http.py | 8 +- src/flash/core/utilities/flash_cli.py | 2 +- src/flash/core/utilities/lightning_cli.py | 2 +- src/flash/graph/classification/model.py | 4 +- src/flash/graph/embedding/model.py | 6 +- src/flash/image/classification/adapters.py | 2 +- src/flash/image/classification/data.py | 2 +- src/flash/image/classification/input.py | 6 +- .../classification/integrations/baal/data.py | 2 +- src/flash/image/data.py | 4 +- src/flash/image/detection/input.py | 6 +- .../image/embedding/losses/vissl_losses.py | 2 +- src/flash/image/segmentation/input.py | 4 +- src/flash/image/segmentation/model.py | 2 +- src/flash/image/style_transfer/model.py | 4 +- src/flash/pointcloud/detection/model.py | 2 +- .../pointcloud/detection/open3d_ml/input.py | 2 +- src/flash/pointcloud/segmentation/model.py | 2 +- src/flash/template/classification/model.py | 2 +- src/flash/text/classification/data.py | 2 +- src/flash/video/classification/data.py | 2 +- src/flash/video/classification/input.py | 4 +- .../data/utilities/test_classification.py | 2 +- tests/core/data/utilities/test_loading.py | 4 +- tests/core/data/utilities/test_paths.py | 2 +- .../labelstudio/test_labelstudio.py | 2 +- tests/core/serve/models.py | 2 +- tests/core/serve/test_components.py | 8 +- .../core/serve/test_dag/test_optimization.py | 2 +- tests/core/serve/test_dag/test_rewrite.py | 2 +- tests/core/serve/test_gridbase_validations.py | 2 +- tests/core/serve/test_integration.py | 2 +- tests/core/test_finetuning.py | 3 +- tests/core/test_model.py | 2 +- tests/core/utilities/test_lightning_cli.py | 2 +- 67 files changed, 212 insertions(+), 205 deletions(-) delete mode 100644 setup.cfg diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e29597b17f..7573e42d35 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,9 +27,10 @@ repos: hooks: - id: end-of-file-fixer - id: trailing-whitespace + - id: check-json - id: check-yaml - - id: check-docstring-first - id: check-toml + - id: check-docstring-first - id: check-case-conflict - id: check-added-large-files - id: detect-private-key @@ -41,12 +42,6 @@ repos: args: [--py37-plus] name: Upgrade code - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - name: imports - - repo: https://github.com/kynan/nbstripout rev: 0.6.1 hooks: @@ -56,7 +51,10 @@ repos: rev: v1.5.0 hooks: - id: docformatter - args: [--in-place, --wrap-summaries=120, --wrap-descriptions=120] + args: + - "--in-place" + - "--wrap-summaries=120" + - "--wrap-descriptions=120" - repo: https://github.com/psf/black rev: 23.1.0 @@ -64,14 +62,22 @@ repos: - id: black name: Format code + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + name: imports + - repo: https://github.com/asottile/blacken-docs rev: 1.13.0 hooks: - id: blacken-docs - args: [ --line-length=120, --skip-errors ] + args: + - "--line-length=120" + - "--skip-errors" - - repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.0.240 hooks: - - id: flake8 - name: PEP8 + - id: ruff + args: ["--fix"] diff --git a/examples/serve/generic/boston_prediction/inference_server.py b/examples/serve/generic/boston_prediction/inference_server.py index acd1735ae9..995ec3917f 100644 --- a/examples/serve/generic/boston_prediction/inference_server.py +++ b/examples/serve/generic/boston_prediction/inference_server.py @@ -14,7 +14,7 @@ import hummingbird.ml import sklearn.datasets -from flash.core.serve import Composition, expose, ModelComponent +from flash.core.serve import Composition, ModelComponent, expose from flash.core.serve.types import Number, Table feature_names = [ diff --git a/examples/serve/generic/detection/inference.py b/examples/serve/generic/detection/inference.py index 813359a6dc..2ae25affa1 100644 --- a/examples/serve/generic/detection/inference.py +++ b/examples/serve/generic/detection/inference.py @@ -13,7 +13,7 @@ # limitations under the License. import torchvision -from flash.core.serve import Composition, expose, ModelComponent +from flash.core.serve import Composition, ModelComponent, expose from flash.core.serve.types import BBox, Image, Label, Repeated diff --git a/examples/visualizations/pointcloud_detection.py b/examples/visualizations/pointcloud_detection.py index 50b4e62909..9c3318960f 100644 --- a/examples/visualizations/pointcloud_detection.py +++ b/examples/visualizations/pointcloud_detection.py @@ -15,7 +15,7 @@ import flash from flash.core.data.utils import download_data -from flash.pointcloud.detection import launch_app, PointCloudObjectDetector, PointCloudObjectDetectorData +from flash.pointcloud.detection import PointCloudObjectDetector, PointCloudObjectDetectorData, launch_app # 1. Create the DataModule # Dataset Credit: http://www.semantic-kitti.org/ diff --git a/examples/visualizations/pointcloud_segmentation.py b/examples/visualizations/pointcloud_segmentation.py index c2486592b5..8c4657f9f8 100644 --- a/examples/visualizations/pointcloud_segmentation.py +++ b/examples/visualizations/pointcloud_segmentation.py @@ -15,7 +15,7 @@ import flash from flash.core.data.utils import download_data -from flash.pointcloud.segmentation import launch_app, PointCloudSegmentation, PointCloudSegmentationData +from flash.pointcloud.segmentation import PointCloudSegmentation, PointCloudSegmentationData, launch_app # 1. Create the DataModule # Dataset Credit: http://www.semantic-kitti.org/ diff --git a/pyproject.toml b/pyproject.toml index e18a6fbac5..38ebddb0b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,115 @@ -[tool.autopep8] -ignore = ["E731"] +[metadata] +license_file = "LICENSE" +description-file = "README.md" + +[build-system] +requires = [ + "setuptools", + "wheel", +] + + +[tool.check-manifest] +ignore = [ + "*.yml", + ".github", + ".github/*" +] + + +[tool.pytest.ini_options] +norecursedirs = [ + ".git", + ".github", + "dist", + "build", + "docs", +] +addopts = [ + "--strict-markers", + "--doctest-modules", + "--color=yes", + "--disable-pytest-warnings", +] +#filterwarnings = [ +# "error::FutureWarning", +#] +xfail_strict = false # todo +junit_duration_report = "call" + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "pass", +] [tool.black] +# https://github.com/psf/black +line-length = 120 +exclude = "(.eggs|.git|.hg|.mypy_cache|.venv|_build|buck-out|build|dist)" + +[tool.isort] +known_first_party = [ + "flash", + "examples", + "tests", +] +skip_glob = [] +profile = "black" +line_length = 120 + + +[tool.ruff] line-length = 120 +# Enable Pyflakes `E` and `F` codes by default. +select = [ + "E", "W", # see: https://pypi.org/project/pycodestyle + "F", # see: https://pypi.org/project/pyflakes +# "D", # see: https://pypi.org/project/pydocstyle +# "N", # see: https://pypi.org/project/pep8-naming +] +ignore = [ + "E731", # Do not assign a lambda expression, use a def +] +# Exclude a variety of commonly ignored directories. +exclude = [ + ".eggs", + ".git", + ".mypy_cache", + ".ruff_cache", + "__pypackages__", + "_build", + "build", + "dist", + "docs" +] +ignore-init-module-imports = true + +[tool.ruff.per-file-ignores] +"setup.py" = ["D100", "SIM115"] +"__about__.py" = ["D100"] +"__init__.py" = ["D100"] + +[tool.ruff.pydocstyle] +# Use Google-style docstrings. +convention = "google" + + +[tool.mypy] +files = [ + "src", +] +install_types = true +non_interactive = true +disallow_untyped_defs = true +ignore_missing_imports = true +show_error_codes = true +warn_redundant_casts = true +warn_unused_configs = true +warn_unused_ignores = true +allow_redefinition = true +# disable this rule as the Trainer attributes are defined in the connectors, not in its __init__ +disable_error_code = "attr-defined" +# style choices +warn_no_return = false diff --git a/requirements/test.txt b/requirements/test.txt index fd36809355..86953722af 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,8 +1,8 @@ # NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup -coverage +coverage[toml] codecov>=2.1 -pytest>=5.0, <7.0 +pytest>=6.0, <7.0 pytest-doctestplus>=0.9.0 pytest-rerunfailures>=10.0 pytest-forked diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 6ab4af21b5..0000000000 --- a/setup.cfg +++ /dev/null @@ -1,101 +0,0 @@ -[metadata] -license_file = LICENSE -description-file = README.md - - -[tool:pytest] -norecursedirs = - .git - dist - build -doctest_plus = enabled -addopts = - --durations=0 - --color=yes - - -[coverage:report] -exclude_lines = - pragma: no-cover - pass - if __name__ == .__main__.: - add_model_specific_args - - -[isort] -known_first_party = - flash - examples - tests -line_length = 120 -order_by_type = False -# 3 - Vertical Hanging Indent -multi_line_output = 3 -include_trailing_comma = True - - -[flake8] -max-line-length = 120 -extend-ignore = E203, W503 -ignore = - # Line break occurred after a binary operator - W504 -exclude = - *.egg - build - temp - .git -select = E,W,F -doctests = True -verbose = 2 -# https://pep8.readthedocs.io/en/latest/intro.html#error-codes -format = pylint -# see: https://www.flake8rules.com/ - - -[versioneer] -VCS = git -style = pep440 -versionfile_source = flash/_version.py -versionfile_build = flash/_version.py -tag_prefix = v -parentdir_prefix = - - -# setup.cfg or tox.ini -[check-manifest] -ignore = - *.yml - .github - .github/* - - -[mypy] -# Typing tests is low priority, but enabling type checking on the -# untyped test functions (using `--check-untyped-defs`) is still -# high-value because it helps test the typing. -files = flash, examples, tests -pretty = True -show_error_codes = True -disallow_untyped_defs = True -ignore_missing_imports = True - -# todo: add proper typing to this module... -[mypy-flash.core.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-flash.tabular.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-flash.text.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-flash.image.*] -ignore_errors = True - -# todo -[mypy-tests.*] -ignore_errors = True diff --git a/src/flash/audio/classification/input.py b/src/flash/audio/classification/input.py index 07eb0e1e67..812a924d4d 100644 --- a/src/flash/audio/classification/input.py +++ b/src/flash/audio/classification/input.py @@ -24,11 +24,11 @@ from flash.core.data.utilities.loading import ( AUDIO_EXTENSIONS, IMG_EXTENSIONS, + NP_EXTENSIONS, load_data_frame, load_spectrogram, - NP_EXTENSIONS, ) -from flash.core.data.utilities.paths import filter_valid_files, make_dataset, PATH_TYPE +from flash.core.data.utilities.paths import PATH_TYPE, filter_valid_files, make_dataset from flash.core.data.utilities.samples import to_samples from flash.core.utilities.imports import requires diff --git a/src/flash/core/adapter.py b/src/flash/core/adapter.py index d9c5328739..559c90ce13 100644 --- a/src/flash/core/adapter.py +++ b/src/flash/core/adapter.py @@ -15,7 +15,7 @@ from typing import Any, Callable, Optional import torch.jit -from torch import nn, Tensor +from torch import Tensor, nn from torch.utils.data import DataLoader, Sampler import flash diff --git a/src/flash/core/data/batch.py b/src/flash/core/data/batch.py index 5b5f1dd24d..ba16ec429c 100644 --- a/src/flash/core/data/batch.py +++ b/src/flash/core/data/batch.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, List -from torch import nn, Tensor +from torch import Tensor, nn from flash.core.data.utilities.classification import _is_list_like diff --git a/src/flash/core/data/data_module.py b/src/flash/core/data/data_module.py index ccc944d142..5c4cae6a99 100644 --- a/src/flash/core/data/data_module.py +++ b/src/flash/core/data/data_module.py @@ -26,10 +26,10 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.io.input import DataKeys, Input, IterableInput from flash.core.data.io.input_transform import ( + InputTransform, create_device_input_transform_processor, create_or_configure_input_transform, create_worker_input_transform_processor, - InputTransform, ) from flash.core.data.splits import SplitDataset from flash.core.data.utils import _STAGES_PREFIX diff --git a/src/flash/core/data/io/classification_input.py b/src/flash/core/data/io/classification_input.py index 6d46e41782..4c5cd60622 100644 --- a/src/flash/core/data/io/classification_input.py +++ b/src/flash/core/data/io/classification_input.py @@ -14,7 +14,7 @@ from typing import Any, List, Optional from flash.core.data.properties import Properties -from flash.core.data.utilities.classification import get_target_formatter, TargetFormatter +from flash.core.data.utilities.classification import TargetFormatter, get_target_formatter class ClassificationInputMixin(Properties): diff --git a/src/flash/core/data/io/input.py b/src/flash/core/data/io/input.py index 0a35195792..f8e3bdc96e 100644 --- a/src/flash/core/data/io/input.py +++ b/src/flash/core/data/io/input.py @@ -14,7 +14,7 @@ import functools import os from enum import Enum -from typing import Any, cast, Dict, Iterable, List, Sequence, Tuple, Union +from typing import Any, Dict, Iterable, List, Sequence, Tuple, Union, cast from pytorch_lightning.utilities.enums import LightningEnum from torch.utils.data import Dataset diff --git a/src/flash/core/data/utilities/classification.py b/src/flash/core/data/utilities/classification.py index 312c911031..de78e6ee88 100644 --- a/src/flash/core/data/utilities/classification.py +++ b/src/flash/core/data/utilities/classification.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass from functools import reduce -from typing import Any, cast, ClassVar, Dict, List, Optional, Tuple, Type, Union +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union, cast import numpy as np import torch diff --git a/src/flash/core/data/utilities/paths.py b/src/flash/core/data/utilities/paths.py index 472903bae0..7d8850070e 100644 --- a/src/flash/core/data/utilities/paths.py +++ b/src/flash/core/data/utilities/paths.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Callable, cast, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union, cast from pytorch_lightning.utilities import rank_zero_warn diff --git a/src/flash/core/integrations/fiftyone/utils.py b/src/flash/core/integrations/fiftyone/utils.py index 8841560b1e..62a5de5094 100644 --- a/src/flash/core/integrations/fiftyone/utils.py +++ b/src/flash/core/integrations/fiftyone/utils.py @@ -80,7 +80,7 @@ def visualize( dataset = fo.Dataset() if filepaths: - dataset.add_samples([fo.Sample(filepath=f, **{label_field: l}) for f, l in zip(filepaths, labels)]) + dataset.add_samples([fo.Sample(filepath=fp, **{label_field: lb}) for fp, lb in zip(filepaths, labels)]) session = fo.launch_app(dataset, **kwargs) if wait: diff --git a/src/flash/core/integrations/icevision/adapter.py b/src/flash/core/integrations/icevision/adapter.py index bf06ae1b38..5282e46af9 100644 --- a/src/flash/core/integrations/icevision/adapter.py +++ b/src/flash/core/integrations/icevision/adapter.py @@ -21,7 +21,7 @@ import flash from flash.core.adapter import Adapter from flash.core.data.io.input import DataKeys, InputBase -from flash.core.data.io.input_transform import create_worker_input_transform_processor, InputTransform +from flash.core.data.io.input_transform import InputTransform, create_worker_input_transform_processor from flash.core.integrations.icevision.transforms import ( from_icevision_predictions, from_icevision_record, diff --git a/src/flash/core/integrations/icevision/data.py b/src/flash/core/integrations/icevision/data.py index 569e20fb01..caed3a3b65 100644 --- a/src/flash/core/integrations/icevision/data.py +++ b/src/flash/core/integrations/icevision/data.py @@ -17,7 +17,7 @@ import numpy as np from flash.core.data.io.input import DataKeys, Input -from flash.core.data.utilities.loading import IMG_EXTENSIONS, load_image, NP_EXTENSIONS +from flash.core.data.utilities.loading import IMG_EXTENSIONS, NP_EXTENSIONS, load_image from flash.core.data.utilities.paths import list_valid_files from flash.core.integrations.icevision.transforms import from_icevision_record from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires diff --git a/src/flash/core/integrations/icevision/transforms.py b/src/flash/core/integrations/icevision/transforms.py index 7bc4c083e5..d7201a146d 100644 --- a/src/flash/core/integrations/icevision/transforms.py +++ b/src/flash/core/integrations/icevision/transforms.py @@ -15,7 +15,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np -from torch import nn, Tensor +from torch import Tensor, nn from flash.core.data.io.input import DataKeys from flash.core.data.io.input_transform import InputTransform diff --git a/src/flash/core/model.py b/src/flash/core/model.py index 1eda6705a3..e8a06fbf02 100644 --- a/src/flash/core/model.py +++ b/src/flash/core/model.py @@ -24,7 +24,7 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.callbacks.finetuning import BaseFinetuning from pytorch_lightning.utilities.enums import LightningEnum -from torch import nn, Tensor +from torch import Tensor, nn from torch.optim.lr_scheduler import _LRScheduler from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader, Sampler @@ -32,9 +32,9 @@ import flash from flash.core.data.io.input import InputBase, ServeInput from flash.core.data.io.input_transform import ( + InputTransform, create_or_configure_input_transform, create_worker_input_transform_processor, - InputTransform, ) from flash.core.data.io.output import Output from flash.core.data.io.output_transform import OutputTransform diff --git a/src/flash/core/optimizers/lamb.py b/src/flash/core/optimizers/lamb.py index 4e727ee7b3..3d96231ff9 100644 --- a/src/flash/core/optimizers/lamb.py +++ b/src/flash/core/optimizers/lamb.py @@ -22,7 +22,6 @@ from typing import Tuple import torch -from torch import nn from torch.optim.optimizer import Optimizer from flash.core.utilities.imports import _CORE_TESTING diff --git a/src/flash/core/optimizers/lars.py b/src/flash/core/optimizers/lars.py index bc4c411cd6..00b0dc8531 100644 --- a/src/flash/core/optimizers/lars.py +++ b/src/flash/core/optimizers/lars.py @@ -19,7 +19,6 @@ # - https://arxiv.org/pdf/1708.03888.pdf # - https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py import torch -from torch import nn from torch.optim.optimizer import Optimizer, required from flash.core.utilities.imports import _CORE_TESTING diff --git a/src/flash/core/optimizers/lr_scheduler.py b/src/flash/core/optimizers/lr_scheduler.py index 497d1d087f..f5ffe7e6b6 100644 --- a/src/flash/core/optimizers/lr_scheduler.py +++ b/src/flash/core/optimizers/lr_scheduler.py @@ -19,8 +19,7 @@ import warnings from typing import List -from torch import nn -from torch.optim import Adam, Optimizer +from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from flash.core.utilities.imports import _CORE_TESTING diff --git a/src/flash/core/optimizers/schedulers.py b/src/flash/core/optimizers/schedulers.py index a976a4129f..aff77c9ce6 100644 --- a/src/flash/core/optimizers/schedulers.py +++ b/src/flash/core/optimizers/schedulers.py @@ -3,13 +3,13 @@ from torch.optim import lr_scheduler from torch.optim.lr_scheduler import ( - _LRScheduler, CosineAnnealingLR, CosineAnnealingWarmRestarts, CyclicLR, MultiStepLR, ReduceLROnPlateau, StepLR, + _LRScheduler, ) from flash.core.registry import FlashRegistry diff --git a/src/flash/core/serve/core.py b/src/flash/core/serve/core.py index 986cb63ba0..06447f9aa9 100644 --- a/src/flash/core/serve/core.py +++ b/src/flash/core/serve/core.py @@ -11,7 +11,7 @@ from flash.core.utilities.imports import _PYDANTIC_AVAILABLE, requires if _PYDANTIC_AVAILABLE: - from pydantic import FilePath, HttpUrl, parse_obj_as, ValidationError + from pydantic import FilePath, HttpUrl, ValidationError, parse_obj_as else: FilePath, HttpUrl, parse_obj_as, ValidationError = None, None, None, None diff --git a/src/flash/core/serve/dag/optimization.py b/src/flash/core/serve/dag/optimization.py index cade83b6df..e5e6f5bd3d 100644 --- a/src/flash/core/serve/dag/optimization.py +++ b/src/flash/core/serve/dag/optimization.py @@ -4,7 +4,6 @@ from flash.core.serve.dag.task import flatten, get, get_dependencies, ishashable, istask, reverse_dict, subs, toposort from flash.core.serve.dag.utils import key_split -from flash.core.serve.dag.utils_test import add, inc, mul from flash.core.utilities.imports import _SERVE_TESTING # Skip doctests if requirements aren't available diff --git a/src/flash/core/serve/dag/order.py b/src/flash/core/serve/dag/order.py index c7ccad6d68..cccd9a6107 100644 --- a/src/flash/core/serve/dag/order.py +++ b/src/flash/core/serve/dag/order.py @@ -79,8 +79,7 @@ from collections import defaultdict from math import log -from flash.core.serve.dag.task import get_dependencies, get_deps, getcycle, reverse_dict -from flash.core.serve.dag.utils_test import add, inc +from flash.core.serve.dag.task import get_dependencies, getcycle, reverse_dict from flash.core.utilities.imports import _SERVE_TESTING # Skip doctests if requirements aren't available diff --git a/src/flash/core/serve/dag/task.py b/src/flash/core/serve/dag/task.py index b9d6696394..6f7c8bebde 100644 --- a/src/flash/core/serve/dag/task.py +++ b/src/flash/core/serve/dag/task.py @@ -1,7 +1,6 @@ from collections import defaultdict from typing import List, Sequence -from flash.core.serve.dag.utils_test import add, inc from flash.core.utilities.imports import _SERVE_TESTING # Skip doctests if requirements aren't available diff --git a/src/flash/core/serve/decorators.py b/src/flash/core/serve/decorators.py index ab9de9f682..0c1ecc9708 100644 --- a/src/flash/core/serve/decorators.py +++ b/src/flash/core/serve/decorators.py @@ -5,7 +5,7 @@ from typing import Dict, List, Sequence, Tuple, Union from uuid import uuid4 -from flash.core.serve.core import Connection, make_param_dict, make_parameter_container, ParameterContainer, Servable +from flash.core.serve.core import Connection, ParameterContainer, Servable, make_param_dict, make_parameter_container from flash.core.serve.types.base import BaseType from flash.core.serve.utils import fn_outputs_to_keyed_map from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _SERVE_TESTING diff --git a/src/flash/core/serve/execution.py b/src/flash/core/serve/execution.py index 1546ff76d9..1039f29a30 100644 --- a/src/flash/core/serve/execution.py +++ b/src/flash/core/serve/execution.py @@ -1,7 +1,7 @@ from collections import defaultdict from dataclasses import dataclass from operator import attrgetter -from typing import Dict, List, Set, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, List, Set, Tuple from flash.core.serve.dag.optimization import cull, functions_of, inline_functions from flash.core.serve.dag.rewrite import RewriteRule, RuleSet diff --git a/src/flash/core/serve/flash_components.py b/src/flash/core/serve/flash_components.py index 37109d4855..40485eca90 100644 --- a/src/flash/core/serve/flash_components.py +++ b/src/flash/core/serve/flash_components.py @@ -8,7 +8,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.io.input import DataKeys from flash.core.data.io.output_transform import OutputTransform -from flash.core.serve import expose, ModelComponent +from flash.core.serve import ModelComponent, expose from flash.core.serve.types.base import BaseType from flash.core.trainer import Trainer from flash.core.utilities.stages import RunningStage diff --git a/src/flash/core/serve/interfaces/http.py b/src/flash/core/serve/interfaces/http.py index 861ad32937..915b3d5f17 100644 --- a/src/flash/core/serve/interfaces/http.py +++ b/src/flash/core/serve/interfaces/http.py @@ -2,17 +2,17 @@ import uuid from io import BytesIO from pathlib import Path -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional from flash.core.serve.dag.task import get from flash.core.serve.dag.visualize import visualize from flash.core.serve.execution import ( - build_composition, - component_dag_content, ComponentJSON, - merged_dag_content, MergedJSON, TaskComposition, + build_composition, + component_dag_content, + merged_dag_content, ) from flash.core.serve.interfaces.models import Alive, EndpointProtocol from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _FASTAPI_AVAILABLE diff --git a/src/flash/core/utilities/flash_cli.py b/src/flash/core/utilities/flash_cli.py index bf2ef4abf4..75f34dd98f 100644 --- a/src/flash/core/utilities/flash_cli.py +++ b/src/flash/core/utilities/flash_cli.py @@ -28,10 +28,10 @@ import flash from flash.core.data.data_module import DataModule from flash.core.utilities.lightning_cli import ( - class_from_function, LightningArgumentParser, LightningCLI, SaveConfigCallback, + class_from_function, ) from flash.core.utilities.stability import beta diff --git a/src/flash/core/utilities/lightning_cli.py b/src/flash/core/utilities/lightning_cli.py index b9d8224115..1d89796a1c 100644 --- a/src/flash/core/utilities/lightning_cli.py +++ b/src/flash/core/utilities/lightning_cli.py @@ -6,7 +6,7 @@ from argparse import Namespace from functools import wraps from types import MethodType -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast import torch from jsonargparse import ActionConfigFile, ArgumentParser, set_config_read_mode diff --git a/src/flash/graph/classification/model.py b/src/flash/graph/classification/model.py index 1dc1c21e48..41556e961a 100644 --- a/src/flash/graph/classification/model.py +++ b/src/flash/graph/classification/model.py @@ -13,9 +13,9 @@ # limitations under the License. from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from torch import nn, Tensor -from torch.nn import functional as F +from torch import Tensor, nn from torch.nn import Linear +from torch.nn import functional as F from flash.core.classification import ClassificationTask from flash.core.data.io.input import DataKeys diff --git a/src/flash/graph/embedding/model.py b/src/flash/graph/embedding/model.py index 33054eed1c..3ddd53a38a 100644 --- a/src/flash/graph/embedding/model.py +++ b/src/flash/graph/embedding/model.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, IO, Optional, Union +from typing import IO, Any, Callable, Dict, Optional, Union import torch -from torch import nn, Tensor +from torch import Tensor, nn from flash.core.data.io.input import DataKeys from flash.core.model import Task -from flash.graph.classification.model import GraphClassifier, POOLING_FUNCTIONS +from flash.graph.classification.model import POOLING_FUNCTIONS, GraphClassifier from flash.graph.collate import _pyg_collate diff --git a/src/flash/image/classification/adapters.py b/src/flash/image/classification/adapters.py index b5949a369f..cc1cd0316f 100644 --- a/src/flash/image/classification/adapters.py +++ b/src/flash/image/classification/adapters.py @@ -22,7 +22,7 @@ from lightning_utilities.core.rank_zero import WarningCache from pytorch_lightning import LightningModule from pytorch_lightning.trainer.states import TrainerFn -from torch import nn, Tensor +from torch import Tensor, nn from torch.utils.data import DataLoader, IterableDataset, Sampler import flash diff --git a/src/flash/image/classification/data.py b/src/flash/image/classification/data.py index df5ec1f89f..5a90d78dfe 100644 --- a/src/flash/image/classification/data.py +++ b/src/flash/image/classification/data.py @@ -25,7 +25,7 @@ from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioImageClassificationInput +from flash.core.integrations.labelstudio.input import LabelStudioImageClassificationInput, _parse_labelstudio_arguments from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, _IMAGE_EXTRAS_TESTING, diff --git a/src/flash/image/classification/input.py b/src/flash/image/classification/input.py index 71731945d8..d770780f0f 100644 --- a/src/flash/image/classification/input.py +++ b/src/flash/image/classification/input.py @@ -21,17 +21,17 @@ from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter from flash.core.data.utilities.data_frame import resolve_files, resolve_targets from flash.core.data.utilities.loading import load_data_frame -from flash.core.data.utilities.paths import filter_valid_files, make_dataset, PATH_TYPE +from flash.core.data.utilities.paths import PATH_TYPE, filter_valid_files, make_dataset from flash.core.data.utilities.samples import to_samples from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires from flash.image.data import ( + IMG_EXTENSIONS, + NP_EXTENSIONS, ImageFilesInput, ImageInput, ImageNumpyInput, ImageTensorInput, - IMG_EXTENSIONS, - NP_EXTENSIONS, ) if _FIFTYONE_AVAILABLE: diff --git a/src/flash/image/classification/integrations/baal/data.py b/src/flash/image/classification/integrations/baal/data.py index 6b60a4280a..7fe1a5fe5a 100644 --- a/src/flash/image/classification/integrations/baal/data.py +++ b/src/flash/image/classification/integrations/baal/data.py @@ -27,7 +27,7 @@ if _BAAL_AVAILABLE: from baal.active.dataset import ActiveLearningDataset - from baal.active.heuristics import AbstractHeuristic, BALD + from baal.active.heuristics import BALD, AbstractHeuristic else: class AbstractHeuristic: diff --git a/src/flash/image/data.py b/src/flash/image/data.py index 75c0e8b9c9..c224707d41 100644 --- a/src/flash/image/data.py +++ b/src/flash/image/data.py @@ -20,8 +20,8 @@ import flash from flash.core.data.io.input import DataKeys, Input, ServeInput -from flash.core.data.utilities.loading import IMG_EXTENSIONS, load_image, NP_EXTENSIONS -from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE +from flash.core.data.utilities.loading import IMG_EXTENSIONS, NP_EXTENSIONS, load_image +from flash.core.data.utilities.paths import PATH_TYPE, filter_valid_files from flash.core.data.utilities.samples import to_samples from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image, requires diff --git a/src/flash/image/detection/input.py b/src/flash/image/detection/input.py index d8a0d60d5b..97daee9f16 100644 --- a/src/flash/image/detection/input.py +++ b/src/flash/image/detection/input.py @@ -16,18 +16,18 @@ from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys from flash.core.data.utilities.classification import TargetFormatter -from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE +from flash.core.data.utilities.paths import PATH_TYPE, filter_valid_files from flash.core.data.utilities.samples import to_samples from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities from flash.core.integrations.icevision.data import IceVisionInput from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, lazy_import, requires from flash.image.data import ( + IMG_EXTENSIONS, + NP_EXTENSIONS, ImageFilesInput, ImageInput, ImageNumpyInput, ImageTensorInput, - IMG_EXTENSIONS, - NP_EXTENSIONS, ) if _FIFTYONE_AVAILABLE: diff --git a/src/flash/image/embedding/losses/vissl_losses.py b/src/flash/image/embedding/losses/vissl_losses.py index ddfd6e05de..2b1a9eb404 100644 --- a/src/flash/image/embedding/losses/vissl_losses.py +++ b/src/flash/image/embedding/losses/vissl_losses.py @@ -21,7 +21,7 @@ if _VISSL_AVAILABLE: import vissl.losses # noqa: F401 from classy_vision.generic.distributed_util import set_cpu_device - from classy_vision.losses import ClassyLoss, LOSS_REGISTRY + from classy_vision.losses import LOSS_REGISTRY, ClassyLoss from vissl.config.attr_dict import AttrDict else: AttrDict = object diff --git a/src/flash/image/segmentation/input.py b/src/flash/image/segmentation/input.py index 7d8455ae63..4662e1c986 100644 --- a/src/flash/image/segmentation/input.py +++ b/src/flash/image/segmentation/input.py @@ -17,8 +17,8 @@ import numpy as np from flash.core.data.io.input import DataKeys, Input -from flash.core.data.utilities.loading import IMG_EXTENSIONS, load_image, NP_EXTENSIONS -from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE +from flash.core.data.utilities.loading import IMG_EXTENSIONS, NP_EXTENSIONS, load_image +from flash.core.data.utilities.paths import PATH_TYPE, filter_valid_files from flash.core.data.utilities.samples import to_samples from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import diff --git a/src/flash/image/segmentation/model.py b/src/flash/image/segmentation/model.py index 4d7d796c9c..f07cb11054 100644 --- a/src/flash/image/segmentation/model.py +++ b/src/flash/image/segmentation/model.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Dict, List, Optional, Type, Union -from torch import nn, Tensor +from torch import Tensor, nn from torch.nn import functional as F from flash.core.classification import ClassificationTask diff --git a/src/flash/image/style_transfer/model.py b/src/flash/image/style_transfer/model.py index 5505d2ff1f..84eb62a04f 100644 --- a/src/flash/image/style_transfer/model.py +++ b/src/flash/image/style_transfer/model.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, cast, List, NoReturn, Optional, Sequence, Tuple, Union +from typing import Any, List, NoReturn, Optional, Sequence, Tuple, Union, cast -from torch import nn, Tensor +from torch import Tensor, nn from flash.core.data.io.input import DataKeys from flash.core.model import Task diff --git a/src/flash/pointcloud/detection/model.py b/src/flash/pointcloud/detection/model.py index 1ceea10ad8..a74ba701aa 100644 --- a/src/flash/pointcloud/detection/model.py +++ b/src/flash/pointcloud/detection/model.py @@ -14,7 +14,7 @@ import sys from typing import Any, Dict, Optional, Tuple, Union -from torch import nn, Tensor +from torch import Tensor, nn from torch.utils.data import DataLoader, Sampler import flash diff --git a/src/flash/pointcloud/detection/open3d_ml/input.py b/src/flash/pointcloud/detection/open3d_ml/input.py index 616fe1d211..397d0da49d 100644 --- a/src/flash/pointcloud/detection/open3d_ml/input.py +++ b/src/flash/pointcloud/detection/open3d_ml/input.py @@ -21,7 +21,7 @@ from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE if _POINTCLOUD_AVAILABLE: - from open3d._ml3d.datasets.kitti import DataProcessing, KITTI + from open3d._ml3d.datasets.kitti import KITTI, DataProcessing class PointCloudObjectDetectionDataFormat(BaseDataFormat): diff --git a/src/flash/pointcloud/segmentation/model.py b/src/flash/pointcloud/segmentation/model.py index 7fbcbc2460..1cbb2c0752 100644 --- a/src/flash/pointcloud/segmentation/model.py +++ b/src/flash/pointcloud/segmentation/model.py @@ -14,7 +14,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union import torch -from torch import nn, Tensor +from torch import Tensor, nn from torch.nn import functional as F from torch.utils.data import DataLoader, Sampler diff --git a/src/flash/template/classification/model.py b/src/flash/template/classification/model.py index e5d812f15f..d8de0b2a15 100644 --- a/src/flash/template/classification/model.py +++ b/src/flash/template/classification/model.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Dict, List, Optional, Tuple, Union -from torch import nn, Tensor +from torch import Tensor, nn from flash.core.classification import ClassificationTask from flash.core.data.io.input import DataKeys diff --git a/src/flash/text/classification/data.py b/src/flash/text/classification/data.py index 09a7b59129..38fb909f44 100644 --- a/src/flash/text/classification/data.py +++ b/src/flash/text/classification/data.py @@ -20,7 +20,7 @@ from flash.core.data.io.input_transform import InputTransform from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioTextClassificationInput +from flash.core.integrations.labelstudio.input import LabelStudioTextClassificationInput, _parse_labelstudio_arguments from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING from flash.core.utilities.stages import RunningStage from flash.text.classification.input import ( diff --git a/src/flash/video/classification/data.py b/src/flash/video/classification/data.py index a49217ad2b..bc24b85988 100644 --- a/src/flash/video/classification/data.py +++ b/src/flash/video/classification/data.py @@ -22,7 +22,7 @@ from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioVideoClassificationInput +from flash.core.integrations.labelstudio.input import LabelStudioVideoClassificationInput, _parse_labelstudio_arguments from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, diff --git a/src/flash/video/classification/input.py b/src/flash/video/classification/input.py index ca66dcb531..5c2fb58d36 100644 --- a/src/flash/video/classification/input.py +++ b/src/flash/video/classification/input.py @@ -20,10 +20,10 @@ from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys, Input, IterableInput -from flash.core.data.utilities.classification import _is_list_like, MultiBinaryTargetFormatter, TargetFormatter +from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter, _is_list_like from flash.core.data.utilities.data_frame import resolve_files, resolve_targets from flash.core.data.utilities.loading import load_data_frame -from flash.core.data.utilities.paths import list_valid_files, make_dataset, PATH_TYPE +from flash.core.data.utilities.paths import PATH_TYPE, list_valid_files, make_dataset from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, lazy_import, requires diff --git a/tests/core/data/utilities/test_classification.py b/tests/core/data/utilities/test_classification.py index 453f570dea..709a9c510f 100644 --- a/tests/core/data/utilities/test_classification.py +++ b/tests/core/data/utilities/test_classification.py @@ -20,7 +20,6 @@ from flash.core.data.utilities.classification import ( CommaDelimitedMultiLabelTargetFormatter, - get_target_formatter, MultiBinaryTargetFormatter, MultiLabelTargetFormatter, MultiNumericTargetFormatter, @@ -29,6 +28,7 @@ SingleLabelTargetFormatter, SingleNumericTargetFormatter, SpaceDelimitedTargetFormatter, + get_target_formatter, ) from flash.core.utilities.imports import _CORE_TESTING diff --git a/tests/core/data/utilities/test_loading.py b/tests/core/data/utilities/test_loading.py index 5717646684..15bc0ec6d7 100644 --- a/tests/core/data/utilities/test_loading.py +++ b/tests/core/data/utilities/test_loading.py @@ -20,12 +20,12 @@ AUDIO_EXTENSIONS, CSV_EXTENSIONS, IMG_EXTENSIONS, + NP_EXTENSIONS, + TSV_EXTENSIONS, load_audio, load_data_frame, load_image, load_spectrogram, - NP_EXTENSIONS, - TSV_EXTENSIONS, ) from flash.core.utilities.imports import ( _AUDIO_AVAILABLE, diff --git a/tests/core/data/utilities/test_paths.py b/tests/core/data/utilities/test_paths.py index a7397a7cf2..ebfc649b05 100644 --- a/tests/core/data/utilities/test_paths.py +++ b/tests/core/data/utilities/test_paths.py @@ -20,7 +20,7 @@ from numpy import random from flash.core.data.utilities.loading import AUDIO_EXTENSIONS, IMG_EXTENSIONS, NP_EXTENSIONS -from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE +from flash.core.data.utilities.paths import PATH_TYPE, filter_valid_files def _make_mock_dir(root, mock_files: List) -> List[PATH_TYPE]: diff --git a/tests/core/integrations/labelstudio/test_labelstudio.py b/tests/core/integrations/labelstudio/test_labelstudio.py index 99171aa70d..cff87fed25 100644 --- a/tests/core/integrations/labelstudio/test_labelstudio.py +++ b/tests/core/integrations/labelstudio/test_labelstudio.py @@ -2,10 +2,10 @@ from flash.core.data.utils import download_data from flash.core.integrations.labelstudio.input import ( - _load_json_data, LabelStudioImageClassificationInput, LabelStudioInput, LabelStudioTextClassificationInput, + _load_json_data, ) from flash.core.integrations.labelstudio.visualizer import launch_app from flash.core.utilities.imports import _CORE_TESTING, _IMAGE_TESTING, _TEXT_TESTING, _VIDEO_TESTING diff --git a/tests/core/serve/models.py b/tests/core/serve/models.py index 89dad5d24c..8b907e932d 100644 --- a/tests/core/serve/models.py +++ b/tests/core/serve/models.py @@ -4,7 +4,7 @@ import torch from torch import Tensor -from flash.core.serve import expose, ModelComponent +from flash.core.serve import ModelComponent, expose from flash.core.serve.types import Image, Label, Number, Repeated from flash.core.utilities.imports import _TORCHVISION_AVAILABLE diff --git a/tests/core/serve/test_components.py b/tests/core/serve/test_components.py index 067665d273..ab971a4724 100644 --- a/tests/core/serve/test_components.py +++ b/tests/core/serve/test_components.py @@ -123,7 +123,7 @@ def test_component_parameters(lightning_squeezenet1_1_obj): @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_invalid_expose_inputs(): - from flash.core.serve import expose, ModelComponent + from flash.core.serve import ModelComponent, expose from flash.core.serve.types import Number lr = LightningSqueezenet() @@ -199,7 +199,7 @@ class FakeParam: @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_invalid_name(lightning_squeezenet1_1_obj): - from flash.core.serve import expose, ModelComponent + from flash.core.serve import ModelComponent, expose from flash.core.serve.types import Number with pytest.raises(SyntaxError): @@ -216,7 +216,7 @@ def predict(param): @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_invalid_config_args(lightning_squeezenet1_1_obj): - from flash.core.serve import expose, ModelComponent + from flash.core.serve import ModelComponent, expose from flash.core.serve.types import Number class SomeComponent(ModelComponent): @@ -243,7 +243,7 @@ def predict(self, param): @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_invalid_model_args(lightning_squeezenet1_1_obj): - from flash.core.serve import expose, ModelComponent + from flash.core.serve import ModelComponent, expose from flash.core.serve.types import Number class SomeComponent(ModelComponent): diff --git a/tests/core/serve/test_dag/test_optimization.py b/tests/core/serve/test_dag/test_optimization.py index 14ace1cf99..366b3cdbbb 100644 --- a/tests/core/serve/test_dag/test_optimization.py +++ b/tests/core/serve/test_dag/test_optimization.py @@ -5,13 +5,13 @@ import pytest from flash.core.serve.dag.optimization import ( + SubgraphCallable, cull, functions_of, fuse, fuse_linear, inline, inline_functions, - SubgraphCallable, ) from flash.core.serve.dag.task import get, get_dependencies from flash.core.serve.dag.utils import apply, partial_by_order diff --git a/tests/core/serve/test_dag/test_rewrite.py b/tests/core/serve/test_dag/test_rewrite.py index 155c2da3df..2bfbae013a 100644 --- a/tests/core/serve/test_dag/test_rewrite.py +++ b/tests/core/serve/test_dag/test_rewrite.py @@ -1,6 +1,6 @@ import pytest -from flash.core.serve.dag.rewrite import args, head, RewriteRule, RuleSet, Traverser, VAR +from flash.core.serve.dag.rewrite import VAR, RewriteRule, RuleSet, Traverser, args, head from flash.core.utilities.imports import _SERVE_TESTING diff --git a/tests/core/serve/test_gridbase_validations.py b/tests/core/serve/test_gridbase_validations.py index c491a4b209..ab8c29fa08 100644 --- a/tests/core/serve/test_gridbase_validations.py +++ b/tests/core/serve/test_gridbase_validations.py @@ -1,6 +1,6 @@ import pytest -from flash.core.serve import expose, ModelComponent +from flash.core.serve import ModelComponent, expose from flash.core.serve.types import Number from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _SERVE_TESTING diff --git a/tests/core/serve/test_integration.py b/tests/core/serve/test_integration.py index 51a79b71ae..8d5063fa14 100644 --- a/tests/core/serve/test_integration.py +++ b/tests/core/serve/test_integration.py @@ -414,7 +414,7 @@ def test_cycle_in_connection_fails(session_global_datadir, lightning_squeezenet1 @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") def test_composition_from_url_torchscript_servable(tmp_path): - from flash.core.serve import expose, ModelComponent, Servable + from flash.core.serve import ModelComponent, Servable, expose from flash.core.serve.types import Number """ diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index f8beb1c4b1..662eddc64a 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -19,9 +19,8 @@ import torch from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from torch import Tensor -from torch.nn import Flatten +from torch.nn import Flatten, Linear, LogSoftmax, Module from torch.nn import functional as F -from torch.nn import Linear, LogSoftmax, Module from torch.utils.data import DataLoader import flash diff --git a/tests/core/test_model.py b/tests/core/test_model.py index d1ce3083a4..c60cf1c9ad 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -25,7 +25,7 @@ import torch from pytorch_lightning import LightningDataModule from pytorch_lightning.callbacks import Callback -from torch import nn, Tensor +from torch import Tensor, nn from torch.nn import functional as F from torch.utils.data import DataLoader from torchmetrics import Accuracy diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py index 8e14e3562b..b1eee81e7e 100644 --- a/tests/core/utilities/test_lightning_cli.py +++ b/tests/core/utilities/test_lightning_cli.py @@ -28,10 +28,10 @@ _TORCHVISION_AVAILABLE, ) from flash.core.utilities.lightning_cli import ( - instantiate_class, LightningArgumentParser, LightningCLI, SaveConfigCallback, + instantiate_class, ) from tests.helpers.boring_model import BoringDataModule, BoringModel From 209b038bdc9969d6dc401f4d8ec91e4d02d58ade Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 24 Mar 2023 12:56:13 +0100 Subject: [PATCH 16/39] update readme --- README.md | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index da21a904a2..da3188e95e 100644 --- a/README.md +++ b/README.md @@ -38,18 +38,6 @@ In a nutshell, Flash is the production grade research framework you always dreamed of but didn't have time to build.
- -## News - -- Sept 30: [Lightning Flash now supports Meta-Learning](https://devblog.pytorchlightning.ai/lightning-flash-now-supports-meta-learning-7c0ac8b1cde7) -- Sept 9: [Lightning Flash 0.5](https://devblog.pytorchlightning.ai/flash-0-5-your-pytorch-ai-factory-81b172ff0d76) -- Jul 12: Flash Task-a-thon community sprint with 25+ community members -- Jul 1: [Lightning Flash 0.4](https://devblog.pytorchlightning.ai/lightning-flash-0-4-flash-serve-fiftyone-multi-label-text-classification-and-jit-support-97428276c06f) -- Jun 22: [Ushering in the New Age of Video Understanding with PyTorch](https://medium.com/pytorch/ushering-in-the-new-age-of-video-understanding-with-pytorch-1d85078e8015) -- May 24: [Lightning Flash 0.3](https://devblog.pytorchlightning.ai/lightning-flash-0-3-new-tasks-visualization-tools-data-pipeline-and-flash-registry-api-1e236ba9530) -- May 20: [Video Understanding with PyTorch](https://towardsdatascience.com/video-understanding-made-simple-with-pytorch-video-and-lightning-flash-c7d65583c37e) -- Feb 2: [Read our launch blogpost](https://pytorch-lightning.medium.com/introducing-lightning-flash-the-fastest-way-to-get-started-with-deep-learning-202f196b3b98) - ## Getting Started From PyPI: @@ -131,6 +119,8 @@ model.serve() or make predictions from raw data directly. ```py +from flash import Trainer + trainer = Trainer(strategy='ddp', accelerator="gpu", gpus=2) dm = SemanticSegmentationData.from_folders(predict_folder="data/CameraRGB") predictions = trainer.predict(model, dm) @@ -145,6 +135,8 @@ Check out this [example](https://github.com/Lightning-AI/lightning-flash/blob/ma This is particularly useful if you use this model in production and want to make sure the model adapts quickly to its new environment with minimal labelled data. ```py +from flash.image import ImageClassifier + model = ImageClassifier( backbone="resnet18", optimizer=torch.optim.Adam, @@ -177,6 +169,8 @@ In detail, the following methods are currently implemented: With Flash, swapping among 40+ optimizers and 15+ schedulers recipes are simple. Find the list of available optimizers, schedulers as follows: ```py +from flash.image import ImageClassifier + ImageClassifier.available_optimizers() # ['A2GradExp', ..., 'Yogi'] @@ -187,7 +181,9 @@ ImageClassifier.available_schedulers() Once you've chosen, create the model: ```py -#### The optimizer of choice can be passed as a +#### The optimizer of choice can be passed as +from flash.image import ImageClassifier + # - String value model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=None) @@ -211,6 +207,8 @@ model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr You can also register you own custom scheduler recipes beforeahand and use them shown as above: ```py +from flash.image import ImageClassifier + @ImageClassifier.lr_schedulers_registry def my_steplr_recipe(optimizer): return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10) @@ -279,13 +277,13 @@ using the [`Lightning CLI`](https://pytorch-lightning.readthedocs.io/en/stable/c To get started and view the available tasks, run: -```py +```bash flash --help ``` For example, to train an image classifier for 10 epochs with a `resnet50` backbone on 2 GPUs using your own data, you can do: -```py +```bash flash image_classification --trainer.max_epochs 10 --trainer.gpus 2 --model.backbone resnet50 from_folders --train_folder {PATH_TO_DATA} ``` From 140fdd6dba9c69431b588af9b232725b23fa11ad Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 29 Mar 2023 15:54:24 +0200 Subject: [PATCH 17/39] ci: refactor testing per topics (#1538) --- .azure/gpu-example-tests.yml | 2 +- ...ing-template.yml => template-examples.yml} | 0 .github/workflows/ci-testing.yml | 64 +++++++------ examples/video_classification.py | 4 +- requirements.txt | 18 ++-- requirements/datatype_image.txt | 2 +- requirements/datatype_image_extras.txt | 4 +- requirements/datatype_image_extras_baal.txt | 1 + src/flash/audio/classification/data.py | 4 +- .../audio/speech_recognition/backbone.py | 4 +- src/flash/audio/speech_recognition/collate.py | 4 +- src/flash/audio/speech_recognition/data.py | 4 +- src/flash/audio/speech_recognition/input.py | 4 +- src/flash/audio/speech_recognition/model.py | 4 +- .../speech_recognition/output_transform.py | 4 +- src/flash/core/data/data_module.py | 4 +- .../core/data/utilities/classification.py | 4 +- src/flash/core/data/utilities/loading.py | 4 +- src/flash/core/data/utils.py | 4 +- .../core/integrations/icevision/transforms.py | 4 +- src/flash/core/model.py | 4 +- src/flash/core/optimizers/lamb.py | 4 +- src/flash/core/optimizers/lars.py | 4 +- src/flash/core/optimizers/lr_scheduler.py | 4 +- src/flash/core/serve/component.py | 4 +- src/flash/core/serve/dag/optimization.py | 4 +- src/flash/core/serve/dag/order.py | 4 +- src/flash/core/serve/dag/rewrite.py | 4 +- src/flash/core/serve/dag/task.py | 4 +- src/flash/core/serve/dag/utils.py | 4 +- src/flash/core/serve/decorators.py | 4 +- src/flash/core/serve/interfaces/models.py | 4 +- src/flash/core/utilities/imports.py | 67 +++++--------- src/flash/core/utilities/stability.py | 4 +- src/flash/graph/backbones.py | 4 +- src/flash/graph/classification/data.py | 4 +- src/flash/graph/classification/input.py | 4 +- .../graph/classification/input_transform.py | 4 +- src/flash/graph/classification/model.py | 4 +- src/flash/graph/collate.py | 4 +- src/flash/image/classification/data.py | 8 +- src/flash/image/detection/data.py | 4 +- src/flash/image/instance_segmentation/data.py | 4 +- src/flash/image/keypoint_detection/data.py | 4 +- src/flash/image/segmentation/data.py | 7 +- src/flash/image/style_transfer/data.py | 4 +- src/flash/image/style_transfer/model.py | 4 +- src/flash/pointcloud/detection/datasets.py | 4 +- .../pointcloud/detection/open3d_ml/app.py | 4 +- .../detection/open3d_ml/backbones.py | 6 +- .../pointcloud/detection/open3d_ml/input.py | 4 +- src/flash/pointcloud/segmentation/datasets.py | 4 +- src/flash/pointcloud/segmentation/model.py | 4 +- .../pointcloud/segmentation/open3d_ml/app.py | 4 +- .../segmentation/open3d_ml/backbones.py | 4 +- .../open3d_ml/sequences_dataset.py | 4 +- src/flash/tabular/classification/data.py | 4 +- src/flash/tabular/classification/model.py | 4 +- src/flash/tabular/forecasting/data.py | 4 +- src/flash/tabular/regression/data.py | 4 +- src/flash/tabular/regression/model.py | 4 +- src/flash/text/classification/data.py | 6 +- src/flash/text/classification/input.py | 4 +- src/flash/text/embedding/backbones.py | 4 +- src/flash/text/embedding/model.py | 4 +- src/flash/text/question_answering/data.py | 4 +- src/flash/text/question_answering/input.py | 4 +- src/flash/text/question_answering/model.py | 4 +- src/flash/text/seq2seq/core/input.py | 4 +- src/flash/text/seq2seq/core/model.py | 4 +- src/flash/text/seq2seq/summarization/data.py | 6 +- src/flash/text/seq2seq/translation/data.py | 6 +- src/flash/video/classification/data.py | 13 +-- src/flash/video/classification/utils.py | 4 +- tests/audio/classification/test_data.py | 20 ++--- tests/audio/classification/test_model.py | 6 +- tests/audio/speech_recognition/test_data.py | 10 +-- .../test_data_model_integration.py | 6 +- tests/audio/speech_recognition/test_model.py | 10 +-- tests/conftest.py | 4 +- tests/core/data/io/test_input.py | 8 +- tests/core/data/io/test_output.py | 4 +- tests/core/data/io/test_output_transform.py | 4 +- tests/core/data/test_base_viz.py | 4 +- tests/core/data/test_batch.py | 6 +- tests/core/data/test_callback.py | 4 +- tests/core/data/test_callbacks.py | 4 +- tests/core/data/test_data_module.py | 12 +-- tests/core/data/test_input_transform.py | 6 +- tests/core/data/test_properties.py | 4 +- tests/core/data/test_splits.py | 8 +- tests/core/data/test_transforms.py | 6 +- .../data/utilities/test_classification.py | 10 +-- tests/core/data/utilities/test_loading.py | 25 +++--- .../labelstudio/test_labelstudio.py | 25 +++--- tests/core/optimizers/test_lr_scheduler.py | 4 +- tests/core/optimizers/test_optimizers.py | 6 +- .../serve/test_compat/test_cached_property.py | 6 +- tests/core/serve/test_components.py | 26 +++--- tests/core/serve/test_composition.py | 20 ++--- .../core/serve/test_dag/test_optimization.py | 44 ++++----- tests/core/serve/test_dag/test_order.py | 56 ++++++------ tests/core/serve/test_dag/test_rewrite.py | 18 ++-- tests/core/serve/test_dag/test_task.py | 38 ++++---- tests/core/serve/test_dag/test_utils.py | 12 +-- tests/core/serve/test_gridbase_validations.py | 6 +- tests/core/serve/test_integration.py | 20 ++--- tests/core/serve/test_types/test_bbox.py | 6 +- tests/core/serve/test_types/test_image.py | 4 +- tests/core/serve/test_types/test_label.py | 10 +-- tests/core/serve/test_types/test_number.py | 6 +- tests/core/serve/test_types/test_repeated.py | 12 +-- tests/core/serve/test_types/test_table.py | 12 +-- tests/core/serve/test_types/test_text.py | 6 +- tests/core/test_classification.py | 8 +- tests/core/test_data.py | 6 +- tests/core/test_finetuning.py | 8 +- tests/core/test_model.py | 60 ++++++------- tests/core/test_registry.py | 10 +-- tests/core/test_trainer.py | 16 ++-- tests/core/test_utils.py | 8 +- tests/core/utilities/test_embedder.py | 13 +-- tests/core/utilities/test_lightning_cli.py | 50 +++++------ tests/core/utilities/test_stability.py | 4 +- tests/examples/test_integrations.py | 17 ++-- tests/examples/test_scripts.py | 90 ++++++++++--------- tests/examples/utils.py | 6 +- tests/graph/classification/test_data.py | 6 +- tests/graph/classification/test_model.py | 10 +-- tests/graph/embedding/test_model.py | 14 +-- .../classification/test_active_learning.py | 10 ++- tests/image/classification/test_data.py | 37 ++++---- .../test_data_model_integration.py | 6 +- tests/image/classification/test_model.py | 16 ++-- .../test_training_strategies.py | 4 +- tests/image/detection/test_data.py | 9 +- .../detection/test_data_model_integration.py | 5 +- tests/image/detection/test_model.py | 22 ++--- tests/image/detection/test_output.py | 4 +- tests/image/embedding/test_model.py | 14 +-- .../__init__.py | 0 .../test_data.py | 8 +- .../test_model.py | 12 +-- tests/image/keypoint_detection/test_data.py | 6 +- tests/image/keypoint_detection/test_model.py | 12 +-- .../__init__.py | 0 .../test_backbones.py | 0 .../test_data.py | 19 ++-- .../test_heads.py | 4 +- .../test_model.py | 20 ++--- .../test_output.py | 10 +-- tests/image/style_transfer/test_model.py | 8 +- tests/image/test_backbones.py | 16 ++-- tests/pointcloud/detection/test_data.py | 6 +- tests/pointcloud/detection/test_model.py | 8 +- tests/pointcloud/segmentation/test_data.py | 4 +- .../pointcloud/segmentation/test_datasets.py | 4 +- tests/pointcloud/segmentation/test_model.py | 10 +-- tests/serve/__init__.py | 1 + tests/tabular/classification/test_data.py | 24 ++--- .../test_data_model_integration.py | 6 +- tests/tabular/classification/test_model.py | 8 +- tests/tabular/forecasting/test_data.py | 8 +- tests/tabular/forecasting/test_model.py | 8 +- .../regression/test_data_model_integration.py | 10 +-- tests/tabular/regression/test_model.py | 8 +- tests/template/classification/test_data.py | 4 +- tests/template/classification/test_model.py | 18 ++-- tests/text/classification/test_data.py | 34 +++---- .../test_data_model_integration.py | 4 +- tests/text/classification/test_model.py | 8 +- tests/text/embedding/test_model.py | 8 +- tests/text/question_answering/test_data.py | 12 +-- tests/text/question_answering/test_model.py | 8 +- tests/text/seq2seq/summarization/test_data.py | 10 +-- .../text/seq2seq/summarization/test_model.py | 8 +- tests/text/seq2seq/translation/test_data.py | 10 +-- tests/text/seq2seq/translation/test_model.py | 8 +- tests/video/classification/test_data.py | 10 +-- tests/video/classification/test_model.py | 22 ++--- 180 files changed, 893 insertions(+), 888 deletions(-) rename .azure/{testing-template.yml => template-examples.yml} (100%) rename tests/image/{instance_segmentation => instance_segm}/__init__.py (100%) rename tests/image/{instance_segmentation => instance_segm}/test_data.py (88%) rename tests/image/{instance_segmentation => instance_segm}/test_model.py (90%) rename tests/image/{segmentation => semantic_segm}/__init__.py (100%) rename tests/image/{segmentation => semantic_segm}/test_backbones.py (100%) rename tests/image/{segmentation => semantic_segm}/test_data.py (94%) rename tests/image/{segmentation => semantic_segm}/test_heads.py (93%) rename tests/image/{segmentation => semantic_segm}/test_model.py (80%) rename tests/image/{segmentation => semantic_segm}/test_output.py (86%) create mode 100644 tests/serve/__init__.py diff --git a/.azure/gpu-example-tests.yml b/.azure/gpu-example-tests.yml index ea322e8625..3820f48254 100644 --- a/.azure/gpu-example-tests.yml +++ b/.azure/gpu-example-tests.yml @@ -8,7 +8,7 @@ pr: drafts: 'true' jobs: -- template: testing-template.yml +- template: template-examples.yml parameters: domains: - "image" diff --git a/.azure/testing-template.yml b/.azure/template-examples.yml similarity index 100% rename from .azure/testing-template.yml rename to .azure/template-examples.yml diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 5f5dfe999b..045bc72f37 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -27,25 +27,25 @@ jobs: os: [ubuntu-20.04, macOS-12, windows-2022] python-version: [3.7, 3.9] requires: ['oldest', 'latest'] - topic: [['core']] - release: [ 'stable' ] + topic: ['core'] + extra: [[]] exclude: # Skip if torch<1.8 and py3.9 on Linux: https://github.com/pytorch/pytorch/issues/50014 - - { os: ubuntu-20.04, python-version: 3.9, requires: 'oldest' } - - { os: ubuntu-20.04, python-version: 3.9, requires: 'latest' } + - { python-version: 3.9, requires: 'oldest' } + - { os: macOS-12, requires: 'oldest' } + - { os: windows-2022, requires: 'oldest' } include: - - { os: 'ubuntu-20.04', python-version: 3.9, release: 'pre', topic: [ 'core' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'image' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'image','image_extras' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'image','image_extras_baal' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'video' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'video','video_extras' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'tabular' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'text' ] } - - { os: 'ubuntu-20.04', python-version: 3.8, release: 'stable', topic: [ 'pointcloud' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'serve' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'graph' ] } - - { os: 'ubuntu-20.04', python-version: 3.9, release: 'stable', topic: [ 'audio' ] } + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'core', extra: []} + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_extras']} + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_extras_baal']} + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'video', extra: []} + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'video', extra: ['video_extras']} + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'tabular', extra: []} + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'text', extra: []} + - { os: 'ubuntu-20.04', python-version: 3.8, topic: 'pointcloud', extra: []} + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'serve', extra: []} + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'graph', extra: []} + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'audio', extra: []} # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 50 @@ -85,10 +85,6 @@ jobs: open(fname, 'w').writelines(lines) shell: python - - run: echo "period=$(python -c 'import time; days = time.time() / 60 / 60 / 24; print(int(days / 7))' 2>&1)" >> $GITHUB_OUTPUT - if: matrix.requires != 'latest' - id: times - - name: Install graph test dependencies if: contains( matrix.topic , 'graph' ) run: | @@ -97,14 +93,21 @@ jobs: pip install torch-sparse -f https://data.pyg.org/whl/torch-1.11.0+cpu.html pip install torch-cluster -f https://data.pyg.org/whl/torch-1.11.0+cpu.html - - name: Install dependencies + - name: Adjust extras + run: | + import os + extras = ['${{ matrix.topic }}'] + ${{ toJSON(matrix.extra) }} + with open(os.getenv('GITHUB_ENV'), "a") as gh_env: + gh_env.write(f"EXTRAS={','.join(extras)}") + shell: python + + - name: Install package & dependencies + env: + SYSTEM_VERSION_COMPAT: 1 run: | - python --version pip --version - flag=$(python -c "print('--pre' if '${{matrix.release}}' == 'pre' else '')" 2>&1) - pip install torch>=1.7.1 - pip install '.[${{ join(matrix.topic, ',') }}]' --upgrade $flag --find-links https://download.pytorch.org/whl/cpu/torch_stable.html - pip install '.[test]' --upgrade + pip install cython "torch>=1.7.1" -f https://download.pytorch.org/whl/cpu/torch_stable.html + pip install .[$EXTRAS,test] --upgrade --prefer-binary --find-links https://download.pytorch.org/whl/cpu/torch_stable.html - name: Install vissl if: contains( matrix.topic , 'image_extras' ) @@ -139,10 +142,15 @@ jobs: run: | pip list # FixMe: include doctests for src/ - coverage run --source flash -m pytest tests/ -v --reruns 3 --reruns-delay 2 + coverage run --source flash -m pytest \ + tests/core \ + tests/deprecated_api \ + tests/examples \ + tests/template \ + tests/${{ matrix.topic }} \ + -v # --reruns 3 --reruns-delay 2 - name: Statistics - if: success() run: | coverage report coverage xml diff --git a/examples/video_classification.py b/examples/video_classification.py index c504335e86..fc66c8ab00 100644 --- a/examples/video_classification.py +++ b/examples/video_classification.py @@ -34,9 +34,7 @@ model = VideoClassifier(backbone="x3d_xs", labels=datamodule.labels, pretrained=False) # 3. Create the trainer and finetune the model -trainer = flash.Trainer( - max_epochs=1, gpus=torch.cuda.device_count(), strategy="ddp" if torch.cuda.device_count() > 1 else None -) +trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count() if torch.cuda.device_count() > 1 else None) trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Make a prediction diff --git a/requirements.txt b/requirements.txt index e7da88a101..339b2bbead 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,15 @@ # NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup packaging -setuptools<=59.5.0 # Prevent install bug with tensorboard -numpy<1.24 # strict - freeze for using np.long -torch>=1.7.1 -torchmetrics>0.5.1, <0.11.0 # strict -pytorch-lightning>=1.3.6, <1.9.0 # strict +setuptools <=59.5.0 # Prevent install bug with tensorboard +numpy <1.24 # strict - freeze for using np.long +torch >=1.7.1 +torchmetrics >0.5.1, <0.11.0 # strict +pytorch-lightning >=1.3.6, <1.9.0 # strict pyDeprecate -pandas>=1.1.0, <=1.5.2 -jsonargparse[signatures]>=3.17.0, <=4.9.0 -click>=7.1.2, <=8.1.3 +pandas >=1.1.0, <=1.5.2 +jsonargparse[signatures] >=3.17.0, <=4.9.0 +click >=7.1.2, <=8.1.3 protobuf <=3.20.1 -fsspec[http]>=2021.6.1,<=2022.7.1 +fsspec[http] >=2022.5.0,<=2022.7.1 lightning-utilities >=0.4.1 diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt index 1237792cd6..ea7550693a 100644 --- a/requirements/datatype_image.txt +++ b/requirements/datatype_image.txt @@ -4,7 +4,7 @@ torchvision <=0.14.1 timm>=0.4.5, <=0.4.12 lightning-bolts>=0.3.3, <=0.6.0 Pillow>=7.2, <=9.3.0 -albumentations>=1.0, <=1.3.0 +albumentations <=1.3.0 pystiche>=1.0.0, <=1.0.1 segmentation-models-pytorch>=0.2.0, <=0.3.1 ftfy <=6.1.1 diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt index 8b58b7aafa..ff6e9c9c7e 100644 --- a/requirements/datatype_image_extras.txt +++ b/requirements/datatype_image_extras.txt @@ -4,12 +4,12 @@ matplotlib <=3.6.2 fiftyone classy_vision vissl>=0.1.5 -icevision>=0.8 +icevision >=0.8 sahi >=0.8.19,<0.11.0 icedata effdet kornia>=0.5.1 -learn2learn +learn2learn; platform_system != "Windows" # dead fastface fairscale diff --git a/requirements/datatype_image_extras_baal.txt b/requirements/datatype_image_extras_baal.txt index 8cfaa31311..655862dfd9 100644 --- a/requirements/datatype_image_extras_baal.txt +++ b/requirements/datatype_image_extras_baal.txt @@ -2,3 +2,4 @@ # This is a separate file, as baal integration is affected by vissl installation (conflicts) baal>=1.3.2, <=1.7.0 +icevision >=0.8 diff --git a/src/flash/audio/classification/data.py b/src/flash/audio/classification/data.py index eb21fce711..721c0fa36b 100644 --- a/src/flash/audio/classification/data.py +++ b/src/flash/audio/classification/data.py @@ -32,12 +32,12 @@ from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.utilities.imports import _AUDIO_TESTING +from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.image.classification.data import MatplotlibVisualization # Skip doctests if requirements aren't available -if not _AUDIO_TESTING: +if not _TOPIC_AUDIO_AVAILABLE: __doctest_skip__ = ["AudioClassificationData", "AudioClassificationData.*"] diff --git a/src/flash/audio/speech_recognition/backbone.py b/src/flash/audio/speech_recognition/backbone.py index 3c4298e1e5..84d7c92bfa 100644 --- a/src/flash/audio/speech_recognition/backbone.py +++ b/src/flash/audio/speech_recognition/backbone.py @@ -14,12 +14,12 @@ from functools import partial from flash.core.registry import ExternalRegistry, FlashRegistry -from flash.core.utilities.imports import _AUDIO_AVAILABLE +from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE from flash.core.utilities.providers import _FAIRSEQ, _HUGGINGFACE SPEECH_RECOGNITION_BACKBONES = FlashRegistry("backbones") -if _AUDIO_AVAILABLE: +if _TOPIC_AUDIO_AVAILABLE: from transformers import AutoModelForCTC, Wav2Vec2ForCTC WAV2VEC_MODELS = ["facebook/wav2vec2-base-960h", "facebook/wav2vec2-large-960h-lv60"] diff --git a/src/flash/audio/speech_recognition/collate.py b/src/flash/audio/speech_recognition/collate.py index a8723c9a4b..0346bb04f4 100644 --- a/src/flash/audio/speech_recognition/collate.py +++ b/src/flash/audio/speech_recognition/collate.py @@ -17,9 +17,9 @@ from torch import Tensor from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _AUDIO_AVAILABLE +from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE -if _AUDIO_AVAILABLE: +if _TOPIC_AUDIO_AVAILABLE: from transformers import AutoProcessor else: AutoProcessor = object diff --git a/src/flash/audio/speech_recognition/data.py b/src/flash/audio/speech_recognition/data.py index 4b79bcafa5..86127c41e6 100644 --- a/src/flash/audio/speech_recognition/data.py +++ b/src/flash/audio/speech_recognition/data.py @@ -25,11 +25,11 @@ from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform -from flash.core.utilities.imports import _AUDIO_TESTING +from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE from flash.core.utilities.stages import RunningStage # Skip doctests if requirements aren't available -if not _AUDIO_TESTING: +if not _TOPIC_AUDIO_AVAILABLE: __doctest_skip__ = ["SpeechRecognitionData", "SpeechRecognitionData.*"] diff --git a/src/flash/audio/speech_recognition/input.py b/src/flash/audio/speech_recognition/input.py index d98ec62580..8e517e3e9c 100644 --- a/src/flash/audio/speech_recognition/input.py +++ b/src/flash/audio/speech_recognition/input.py @@ -24,9 +24,9 @@ from flash.core.data.utilities.loading import AUDIO_EXTENSIONS, load_audio, load_data_frame from flash.core.data.utilities.paths import filter_valid_files, list_valid_files from flash.core.data.utilities.samples import to_sample, to_samples -from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires +from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE, requires -if _AUDIO_AVAILABLE: +if _TOPIC_AUDIO_AVAILABLE: import librosa from datasets import Dataset as HFDataset from datasets import load_dataset diff --git a/src/flash/audio/speech_recognition/model.py b/src/flash/audio/speech_recognition/model.py index 47869c71d4..a3c379c642 100644 --- a/src/flash/audio/speech_recognition/model.py +++ b/src/flash/audio/speech_recognition/model.py @@ -28,10 +28,10 @@ from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.serve import Composition -from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires +from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE, requires from flash.core.utilities.types import INPUT_TRANSFORM_TYPE, LR_SCHEDULER_TYPE, OPTIMIZER_TYPE -if _AUDIO_AVAILABLE: +if _TOPIC_AUDIO_AVAILABLE: from transformers import AutoProcessor diff --git a/src/flash/audio/speech_recognition/output_transform.py b/src/flash/audio/speech_recognition/output_transform.py index ecdc5f326c..1cd1314106 100644 --- a/src/flash/audio/speech_recognition/output_transform.py +++ b/src/flash/audio/speech_recognition/output_transform.py @@ -16,9 +16,9 @@ import torch from flash.core.data.io.output_transform import OutputTransform -from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires +from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE, requires -if _AUDIO_AVAILABLE: +if _TOPIC_AUDIO_AVAILABLE: from transformers import Wav2Vec2CTCTokenizer diff --git a/src/flash/core/data/data_module.py b/src/flash/core/data/data_module.py index 5c4cae6a99..8fb75e5442 100644 --- a/src/flash/core/data/data_module.py +++ b/src/flash/core/data/data_module.py @@ -33,12 +33,12 @@ ) from flash.core.data.splits import SplitDataset from flash.core.data.utils import _STAGES_PREFIX -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE # Skip doctests if requirements aren't available -if not _CORE_TESTING: +if not _TOPIC_CORE_AVAILABLE: __doctest_skip__ = ["DataModule"] diff --git a/src/flash/core/data/utilities/classification.py b/src/flash/core/data/utilities/classification.py index de78e6ee88..6ba85486a1 100644 --- a/src/flash/core/data/utilities/classification.py +++ b/src/flash/core/data/utilities/classification.py @@ -20,10 +20,10 @@ from torch import Tensor from flash.core.data.utilities.sort import sorted_alphanumeric -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE # Skip doctests if requirements aren't available -if not _CORE_TESTING: +if not _TOPIC_CORE_AVAILABLE: __doctest_skip__ = ["*"] diff --git a/src/flash/core/data/utilities/loading.py b/src/flash/core/data/utilities/loading.py index b9111bb723..44c7ba8de4 100644 --- a/src/flash/core/data/utilities/loading.py +++ b/src/flash/core/data/utilities/loading.py @@ -25,9 +25,9 @@ import torch from flash.core.data.utilities.paths import has_file_allowed_extension -from flash.core.utilities.imports import _AUDIO_AVAILABLE, _TORCHVISION_AVAILABLE, Image +from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE, _TORCHVISION_AVAILABLE, Image -if _AUDIO_AVAILABLE: +if _TOPIC_AUDIO_AVAILABLE: from torchaudio.transforms import Spectrogram if _TORCHVISION_AVAILABLE: diff --git a/src/flash/core/data/utils.py b/src/flash/core/data/utils.py index fef288ca9f..3c420f682a 100644 --- a/src/flash/core/data/utils.py +++ b/src/flash/core/data/utils.py @@ -23,11 +23,11 @@ from torch import nn from tqdm.auto import tqdm as tq -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE from flash.core.utilities.stages import RunningStage # Skip doctests if requirements aren't available -if not _CORE_TESTING: +if not _TOPIC_CORE_AVAILABLE: __doctest_skip__ = ["download_data"] _STAGES_PREFIX = { diff --git a/src/flash/core/integrations/icevision/transforms.py b/src/flash/core/integrations/icevision/transforms.py index d7201a146d..e7a521ef63 100644 --- a/src/flash/core/integrations/icevision/transforms.py +++ b/src/flash/core/integrations/icevision/transforms.py @@ -22,11 +22,11 @@ from flash.core.utilities.imports import ( _ICEVISION_AVAILABLE, _ICEVISION_GREATER_EQUAL_0_11_0, - _IMAGE_AVAILABLE, + _TOPIC_IMAGE_AVAILABLE, requires, ) -if _IMAGE_AVAILABLE: +if _TOPIC_IMAGE_AVAILABLE: from PIL import Image if _ICEVISION_AVAILABLE: diff --git a/src/flash/core/model.py b/src/flash/core/model.py index e8a06fbf02..1bc19203f1 100644 --- a/src/flash/core/model.py +++ b/src/flash/core/model.py @@ -47,7 +47,7 @@ from flash.core.registry import FlashRegistry from flash.core.serve.composition import Composition from flash.core.utilities.apply_func import get_callable_dict -from flash.core.utilities.imports import _CORE_TESTING, _PL_GREATER_EQUAL_1_5_0, requires +from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, _TOPIC_CORE_AVAILABLE, requires from flash.core.utilities.providers import _HUGGINGFACE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import ( @@ -61,7 +61,7 @@ ) # Skip doctests if requirements aren't available -if not _CORE_TESTING: +if not _TOPIC_CORE_AVAILABLE: __doctest_skip__ = ["Task", "Task.*"] diff --git a/src/flash/core/optimizers/lamb.py b/src/flash/core/optimizers/lamb.py index 3d96231ff9..53c1fb7038 100644 --- a/src/flash/core/optimizers/lamb.py +++ b/src/flash/core/optimizers/lamb.py @@ -24,10 +24,10 @@ import torch from torch.optim.optimizer import Optimizer -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE # Skip doctests if requirements aren't available -if not _CORE_TESTING: +if not _TOPIC_CORE_AVAILABLE: __doctest_skip__ = ["LAMB"] diff --git a/src/flash/core/optimizers/lars.py b/src/flash/core/optimizers/lars.py index 00b0dc8531..fd3334e0e5 100644 --- a/src/flash/core/optimizers/lars.py +++ b/src/flash/core/optimizers/lars.py @@ -21,10 +21,10 @@ import torch from torch.optim.optimizer import Optimizer, required -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE # Skip doctests if requirements aren't available -if not _CORE_TESTING: +if not _TOPIC_CORE_AVAILABLE: __doctest_skip__ = ["LARS"] diff --git a/src/flash/core/optimizers/lr_scheduler.py b/src/flash/core/optimizers/lr_scheduler.py index f5ffe7e6b6..e0e918ca9f 100644 --- a/src/flash/core/optimizers/lr_scheduler.py +++ b/src/flash/core/optimizers/lr_scheduler.py @@ -22,10 +22,10 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE # Skip doctests if requirements aren't available -if not _CORE_TESTING: +if not _TOPIC_CORE_AVAILABLE: __doctest_skip__ = ["LinearWarmupCosineAnnealingLR"] diff --git a/src/flash/core/serve/component.py b/src/flash/core/serve/component.py index 990b0132e8..2c7dad8e77 100644 --- a/src/flash/core/serve/component.py +++ b/src/flash/core/serve/component.py @@ -7,7 +7,7 @@ from flash.core.serve.core import ParameterContainer, Servable from flash.core.serve.decorators import BoundMeta, UnboundMeta -from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _SERVE_AVAILABLE, requires +from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _TOPIC_SERVE_AVAILABLE, requires if _CYTOOLZ_AVAILABLE: from cytoolz import first, isiterable, valfilter @@ -191,7 +191,7 @@ def __call__(cls, *args, **kwargs): return klass -if _SERVE_AVAILABLE: +if _TOPIC_SERVE_AVAILABLE: class ModelComponent(metaclass=FlashServeMeta): """Represents a computation which is decorated by `@expose`. diff --git a/src/flash/core/serve/dag/optimization.py b/src/flash/core/serve/dag/optimization.py index e5e6f5bd3d..2909d54e85 100644 --- a/src/flash/core/serve/dag/optimization.py +++ b/src/flash/core/serve/dag/optimization.py @@ -4,10 +4,10 @@ from flash.core.serve.dag.task import flatten, get, get_dependencies, ishashable, istask, reverse_dict, subs, toposort from flash.core.serve.dag.utils import key_split -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE # Skip doctests if requirements aren't available -if not _SERVE_TESTING: +if not _TOPIC_SERVE_AVAILABLE: __doctest_skip__ = ["*"] diff --git a/src/flash/core/serve/dag/order.py b/src/flash/core/serve/dag/order.py index cccd9a6107..6d4ae93688 100644 --- a/src/flash/core/serve/dag/order.py +++ b/src/flash/core/serve/dag/order.py @@ -80,10 +80,10 @@ from math import log from flash.core.serve.dag.task import get_dependencies, getcycle, reverse_dict -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE # Skip doctests if requirements aren't available -if not _SERVE_TESTING: +if not _TOPIC_SERVE_AVAILABLE: __doctest_skip__ = ["*"] diff --git a/src/flash/core/serve/dag/rewrite.py b/src/flash/core/serve/dag/rewrite.py index 8a01bea450..63ef792904 100644 --- a/src/flash/core/serve/dag/rewrite.py +++ b/src/flash/core/serve/dag/rewrite.py @@ -1,10 +1,10 @@ from collections import deque from flash.core.serve.dag.task import istask, subs -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE # Skip doctests if requirements aren't available -if not _SERVE_TESTING: +if not _TOPIC_SERVE_AVAILABLE: __doctest_skip__ = ["*"] diff --git a/src/flash/core/serve/dag/task.py b/src/flash/core/serve/dag/task.py index 6f7c8bebde..d1bb72c4f7 100644 --- a/src/flash/core/serve/dag/task.py +++ b/src/flash/core/serve/dag/task.py @@ -1,10 +1,10 @@ from collections import defaultdict from typing import List, Sequence -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE # Skip doctests if requirements aren't available -if not _SERVE_TESTING: +if not _TOPIC_SERVE_AVAILABLE: __doctest_skip__ = ["*"] no_default = "__no_default__" diff --git a/src/flash/core/serve/dag/utils.py b/src/flash/core/serve/dag/utils.py index fd4a9ea818..c7e8925a6d 100644 --- a/src/flash/core/serve/dag/utils.py +++ b/src/flash/core/serve/dag/utils.py @@ -6,10 +6,10 @@ import re from operator import methodcaller -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE # Skip doctests if requirements aren't available -if not _SERVE_TESTING: +if not _TOPIC_SERVE_AVAILABLE: __doctest_skip__ = ["*"] diff --git a/src/flash/core/serve/decorators.py b/src/flash/core/serve/decorators.py index 0c1ecc9708..0675d037ee 100644 --- a/src/flash/core/serve/decorators.py +++ b/src/flash/core/serve/decorators.py @@ -8,10 +8,10 @@ from flash.core.serve.core import Connection, ParameterContainer, Servable, make_param_dict, make_parameter_container from flash.core.serve.types.base import BaseType from flash.core.serve.utils import fn_outputs_to_keyed_map -from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _SERVE_TESTING +from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _TOPIC_SERVE_AVAILABLE # Skip doctests if requirements aren't available -if not _SERVE_TESTING: +if not _TOPIC_SERVE_AVAILABLE: __doctest_skip__ = ["*"] if _CYTOOLZ_AVAILABLE: diff --git a/src/flash/core/serve/interfaces/models.py b/src/flash/core/serve/interfaces/models.py index f3884936af..4d6c84b5b7 100644 --- a/src/flash/core/serve/interfaces/models.py +++ b/src/flash/core/serve/interfaces/models.py @@ -3,10 +3,10 @@ from flash.core.serve.component import ModelComponent from flash.core.serve.core import Endpoint from flash.core.serve.types import Repeated -from flash.core.utilities.imports import _PYDANTIC_AVAILABLE, _SERVE_TESTING +from flash.core.utilities.imports import _PYDANTIC_AVAILABLE, _TOPIC_SERVE_AVAILABLE # Skip doctests if requirements aren't available -if not _SERVE_TESTING: +if not _TOPIC_SERVE_AVAILABLE: __doctest_skip__ = ["EndpointProtocol.*"] if _PYDANTIC_AVAILABLE: diff --git a/src/flash/core/utilities/imports.py b/src/flash/core/utilities/imports.py index 810e878b91..07adb6f7ad 100644 --- a/src/flash/core/utilities/imports.py +++ b/src/flash/core/utilities/imports.py @@ -14,7 +14,6 @@ import functools import importlib import operator -import os import types from typing import List, Tuple, Union @@ -71,6 +70,7 @@ _TORCH_OPTIMIZER_AVAILABLE = module_available("torch_optimizer") _SENTENCE_TRANSFORMERS_AVAILABLE = module_available("sentence_transformers") _DEEPSPEED_AVAILABLE = module_available("deepspeed") +_EFFDET_AVAILABLE = module_available("effdet") if _PIL_AVAILABLE: @@ -94,7 +94,7 @@ class Image: _TM_GREATER_EQUAL_0_10_0 = compare_version("torchmetrics", operator.ge, "0.10.0") _BAAL_GREATER_EQUAL_1_5_2 = compare_version("baal", operator.ge, "1.5.2") -_TEXT_AVAILABLE = all( +_TOPIC_TEXT_AVAILABLE = all( [ _TRANSFORMERS_AVAILABLE, _SENTENCEPIECE_AVAILABLE, @@ -103,9 +103,9 @@ class Image: _SENTENCE_TRANSFORMERS_AVAILABLE, ] ) -_TABULAR_AVAILABLE = _PANDAS_AVAILABLE and _FORECASTING_AVAILABLE and _PYTORCHTABULAR_AVAILABLE -_VIDEO_AVAILABLE = _TORCHVISION_AVAILABLE and _PIL_AVAILABLE and _PYTORCHVIDEO_AVAILABLE and _KORNIA_AVAILABLE -_IMAGE_AVAILABLE = all( +_TOPIC_TABULAR_AVAILABLE = all([_PANDAS_AVAILABLE, _FORECASTING_AVAILABLE, _PYTORCHTABULAR_AVAILABLE]) +_TOPIC_VIDEO_AVAILABLE = all([_TORCHVISION_AVAILABLE, _PIL_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, _KORNIA_AVAILABLE]) +_TOPIC_IMAGE_AVAILABLE = all( [ _TORCHVISION_AVAILABLE, _TIMM_AVAILABLE, @@ -115,22 +115,25 @@ class Image: _SEGMENTATION_MODELS_AVAILABLE, ] ) -_SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE -_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE and _TORCHVISION_AVAILABLE -_AUDIO_AVAILABLE = all([_TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE, _LIBROSA_AVAILABLE, _TRANSFORMERS_AVAILABLE]) -_GRAPH_AVAILABLE = ( - _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE and _NETWORKX_AVAILABLE +_TOPIC_SERVE_AVAILABLE = all([_FASTAPI_AVAILABLE, _PYDANTIC_AVAILABLE, _CYTOOLZ_AVAILABLE, _UVICORN_AVAILABLE]) +_TOPIC_POINTCLOUD_AVAILABLE = all([_OPEN3D_AVAILABLE, _TORCHVISION_AVAILABLE]) +_TOPIC_AUDIO_AVAILABLE = all( + [_TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE, _LIBROSA_AVAILABLE, _TRANSFORMERS_AVAILABLE] ) +_TOPIC_GRAPH_AVAILABLE = all( + [_TORCH_SCATTER_AVAILABLE, _TORCH_SPARSE_AVAILABLE, _TORCH_GEOMETRIC_AVAILABLE, _NETWORKX_AVAILABLE] +) +_TOPIC_CORE_AVAILABLE = _TOPIC_IMAGE_AVAILABLE and _TOPIC_TABULAR_AVAILABLE and _TOPIC_TEXT_AVAILABLE _EXTRAS_AVAILABLE = { - "image": _IMAGE_AVAILABLE, - "tabular": _TABULAR_AVAILABLE, - "text": _TEXT_AVAILABLE, - "video": _VIDEO_AVAILABLE, - "pointcloud": _POINTCLOUD_AVAILABLE, - "serve": _SERVE_AVAILABLE, - "audio": _AUDIO_AVAILABLE, - "graph": _GRAPH_AVAILABLE, + "image": _TOPIC_IMAGE_AVAILABLE, + "tabular": _TOPIC_TABULAR_AVAILABLE, + "text": _TOPIC_TEXT_AVAILABLE, + "video": _TOPIC_VIDEO_AVAILABLE, + "pointcloud": _TOPIC_POINTCLOUD_AVAILABLE, + "serve": _TOPIC_SERVE_AVAILABLE, + "audio": _TOPIC_AUDIO_AVAILABLE, + "graph": _TOPIC_GRAPH_AVAILABLE, } @@ -237,31 +240,3 @@ def _import_module(self): # Update this object's dict so that attribute references are efficient # (__getattr__ is only called on lookups that fail) self.__dict__.update(module.__dict__) - - -# Global variables used for testing purposes (e.g. to only run doctests in the correct CI job) -_CORE_TESTING = True -_IMAGE_TESTING = _IMAGE_AVAILABLE -_IMAGE_EXTRAS_TESTING = True # Not for normal use -_VIDEO_TESTING = _VIDEO_AVAILABLE -_VIDEO_EXTRAS_TESTING = True # Not for normal use -_TABULAR_TESTING = _TABULAR_AVAILABLE -_TEXT_TESTING = _TEXT_AVAILABLE -_SERVE_TESTING = _SERVE_AVAILABLE -_POINTCLOUD_TESTING = _POINTCLOUD_AVAILABLE -_GRAPH_TESTING = _GRAPH_AVAILABLE -_AUDIO_TESTING = _AUDIO_AVAILABLE - -if "FLASH_TEST_TOPIC" in os.environ: - topic = os.environ["FLASH_TEST_TOPIC"] - _CORE_TESTING = topic == "core" - _IMAGE_TESTING = topic == "image" - _IMAGE_EXTRAS_TESTING = topic == "image,image_extras" or topic == "icevision" or topic == "vissl" - _VIDEO_TESTING = topic == "video" - _VIDEO_EXTRAS_TESTING = topic == "video,video_extras" - _TABULAR_TESTING = topic == "tabular" - _TEXT_TESTING = topic == "text" - _SERVE_TESTING = topic == "serve" - _POINTCLOUD_TESTING = topic == "pointcloud" - _GRAPH_TESTING = topic == "graph" - _AUDIO_TESTING = topic == "audio" diff --git a/src/flash/core/utilities/stability.py b/src/flash/core/utilities/stability.py index 45241c853c..16e29a0d7d 100644 --- a/src/flash/core/utilities/stability.py +++ b/src/flash/core/utilities/stability.py @@ -17,10 +17,10 @@ from pytorch_lightning.utilities import rank_zero_warn -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE # Skip doctests if requirements aren't available -if not _CORE_TESTING: +if not _TOPIC_CORE_AVAILABLE: __doctest_skip__ = ["beta"] diff --git a/src/flash/graph/backbones.py b/src/flash/graph/backbones.py index d09262c569..d2186463d3 100644 --- a/src/flash/graph/backbones.py +++ b/src/flash/graph/backbones.py @@ -14,10 +14,10 @@ from functools import partial from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _GRAPH_AVAILABLE +from flash.core.utilities.imports import _TOPIC_GRAPH_AVAILABLE from flash.core.utilities.providers import _PYTORCH_GEOMETRIC -if _GRAPH_AVAILABLE: +if _TOPIC_GRAPH_AVAILABLE: from torch_geometric.nn.models import GAT, GCN, GIN, GraphSAGE MODELS = {"GCN": GCN, "GraphSAGE": GraphSAGE, "GAT": GAT, "GIN": GIN} diff --git a/src/flash/graph/classification/data.py b/src/flash/graph/classification/data.py index 86936406ba..b0c909b53d 100644 --- a/src/flash/graph/classification/data.py +++ b/src/flash/graph/classification/data.py @@ -18,14 +18,14 @@ from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input from flash.core.data.utilities.classification import TargetFormatter -from flash.core.utilities.imports import _GRAPH_TESTING +from flash.core.utilities.imports import _TOPIC_GRAPH_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.graph.classification.input import GraphClassificationDatasetInput from flash.graph.classification.input_transform import GraphClassificationInputTransform # Skip doctests if requirements aren't available -if not _GRAPH_TESTING: +if not _TOPIC_GRAPH_AVAILABLE: __doctest_skip__ = ["GraphClassificationData", "GraphClassificationData.*"] diff --git a/src/flash/graph/classification/input.py b/src/flash/graph/classification/input.py index 2fda1bc32a..3f6ef9060b 100644 --- a/src/flash/graph/classification/input.py +++ b/src/flash/graph/classification/input.py @@ -19,9 +19,9 @@ from flash.core.data.io.input import DataKeys, Input from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.samples import to_sample -from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires +from flash.core.utilities.imports import _TOPIC_GRAPH_AVAILABLE, requires -if _GRAPH_AVAILABLE: +if _TOPIC_GRAPH_AVAILABLE: from torch_geometric.data import Data, InMemoryDataset diff --git a/src/flash/graph/classification/input_transform.py b/src/flash/graph/classification/input_transform.py index bce9236f36..66a0134f00 100644 --- a/src/flash/graph/classification/input_transform.py +++ b/src/flash/graph/classification/input_transform.py @@ -17,10 +17,10 @@ from flash.core.data.io.input import DataKeys from flash.core.data.io.input_transform import InputTransform from flash.core.data.utilities.samples import to_sample -from flash.core.utilities.imports import _GRAPH_AVAILABLE +from flash.core.utilities.imports import _TOPIC_GRAPH_AVAILABLE from flash.graph.collate import _pyg_collate -if _GRAPH_AVAILABLE: +if _TOPIC_GRAPH_AVAILABLE: from torch_geometric.data import Data from torch_geometric.transforms import NormalizeFeatures else: diff --git a/src/flash/graph/classification/model.py b/src/flash/graph/classification/model.py index 41556e961a..de7abbf07c 100644 --- a/src/flash/graph/classification/model.py +++ b/src/flash/graph/classification/model.py @@ -20,12 +20,12 @@ from flash.core.classification import ClassificationTask from flash.core.data.io.input import DataKeys from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _GRAPH_AVAILABLE +from flash.core.utilities.imports import _TOPIC_GRAPH_AVAILABLE from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.graph.backbones import GRAPH_BACKBONES from flash.graph.collate import _pyg_collate -if _GRAPH_AVAILABLE: +if _TOPIC_GRAPH_AVAILABLE: from torch_geometric.nn import global_add_pool, global_max_pool, global_mean_pool POOLING_FUNCTIONS = {"mean": global_mean_pool, "add": global_add_pool, "max": global_max_pool} diff --git a/src/flash/graph/collate.py b/src/flash/graph/collate.py index aba0bcab96..a4681e204f 100644 --- a/src/flash/graph/collate.py +++ b/src/flash/graph/collate.py @@ -16,9 +16,9 @@ from torch.utils.data.dataloader import default_collate from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _GRAPH_AVAILABLE +from flash.core.utilities.imports import _TOPIC_GRAPH_AVAILABLE -if _GRAPH_AVAILABLE: +if _TOPIC_GRAPH_AVAILABLE: from torch_geometric.data import Batch diff --git a/src/flash/image/classification/data.py b/src/flash/image/classification/data.py index 5a90d78dfe..55356bfea6 100644 --- a/src/flash/image/classification/data.py +++ b/src/flash/image/classification/data.py @@ -28,9 +28,8 @@ from flash.core.integrations.labelstudio.input import LabelStudioImageClassificationInput, _parse_labelstudio_arguments from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, - _IMAGE_EXTRAS_TESTING, - _IMAGE_TESTING, _MATPLOTLIB_AVAILABLE, + _TOPIC_IMAGE_AVAILABLE, Image, requires, ) @@ -59,7 +58,7 @@ # Skip doctests if requirements aren't available __doctest_skip__ = [] -if not _IMAGE_TESTING: +if not _TOPIC_IMAGE_AVAILABLE: __doctest_skip__ += [ "ImageClassificationData", "ImageClassificationData.from_files", @@ -69,9 +68,8 @@ "ImageClassificationData.from_tensors", "ImageClassificationData.from_data_frame", "ImageClassificationData.from_csv", + "ImageClassificationData.from_fiftyone", ] -if not _IMAGE_EXTRAS_TESTING: - __doctest_skip__ += ["ImageClassificationData.from_fiftyone"] class ImageClassificationData(DataModule): diff --git a/src/flash/image/detection/data.py b/src/flash/image/detection/data.py index 98ac401d73..aafb8ce99d 100644 --- a/src/flash/image/detection/data.py +++ b/src/flash/image/detection/data.py @@ -26,7 +26,7 @@ from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, - _IMAGE_EXTRAS_TESTING, + _TOPIC_IMAGE_AVAILABLE, Image, requires, ) @@ -55,7 +55,7 @@ Parser = object # Skip doctests if requirements aren't available -if not _IMAGE_EXTRAS_TESTING: +if not _TOPIC_IMAGE_AVAILABLE: __doctest_skip__ = ["ObjectDetectionData", "ObjectDetectionData.*"] diff --git a/src/flash/image/instance_segmentation/data.py b/src/flash/image/instance_segmentation/data.py index 476901a629..9676a629ff 100644 --- a/src/flash/image/instance_segmentation/data.py +++ b/src/flash/image/instance_segmentation/data.py @@ -24,7 +24,7 @@ from flash.core.integrations.icevision.transforms import IceVisionInputTransform from flash.core.utilities.imports import ( _ICEVISION_AVAILABLE, - _IMAGE_EXTRAS_TESTING, + _TOPIC_IMAGE_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_9, ) @@ -51,7 +51,7 @@ class InterpolationMode: # Skip doctests if requirements aren't available -if not _IMAGE_EXTRAS_TESTING: +if not _TOPIC_IMAGE_AVAILABLE: __doctest_skip__ = ["InstanceSegmentationData", "InstanceSegmentationData.*"] diff --git a/src/flash/image/keypoint_detection/data.py b/src/flash/image/keypoint_detection/data.py index d2d422ca79..f4324359fa 100644 --- a/src/flash/image/keypoint_detection/data.py +++ b/src/flash/image/keypoint_detection/data.py @@ -17,7 +17,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input from flash.core.integrations.icevision.data import IceVisionInput -from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_EXTRAS_TESTING +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TOPIC_IMAGE_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.image.keypoint_detection.input_transform import KeypointDetectionInputTransform @@ -32,7 +32,7 @@ # Skip doctests if requirements aren't available -if not _IMAGE_EXTRAS_TESTING: +if not _TOPIC_IMAGE_AVAILABLE: __doctest_skip__ = ["KeypointDetectionData", "KeypointDetectionData.*"] diff --git a/src/flash/image/segmentation/data.py b/src/flash/image/segmentation/data.py index 4ac4dbd774..815d041b19 100644 --- a/src/flash/image/segmentation/data.py +++ b/src/flash/image/segmentation/data.py @@ -19,7 +19,7 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_EXTRAS_TESTING, _IMAGE_TESTING, lazy_import +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TOPIC_IMAGE_AVAILABLE, lazy_import from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.image.segmentation.input import ( @@ -41,16 +41,15 @@ # Skip doctests if requirements aren't available __doctest_skip__ = [] -if not _IMAGE_TESTING: +if not _TOPIC_IMAGE_AVAILABLE: __doctest_skip__ += [ "SemanticSegmentationData", "SemanticSegmentationData.from_files", "SemanticSegmentationData.from_folders", "SemanticSegmentationData.from_numpy", "SemanticSegmentationData.from_tensors", + "SemanticSegmentationData.from_fiftyone", ] -if not _IMAGE_EXTRAS_TESTING: - __doctest_skip__ += ["SemanticSegmentationData.from_fiftyone"] class SemanticSegmentationData(DataModule): diff --git a/src/flash/image/style_transfer/data.py b/src/flash/image/style_transfer/data.py index fea37d99c5..22a9bbda97 100644 --- a/src/flash/image/style_transfer/data.py +++ b/src/flash/image/style_transfer/data.py @@ -18,7 +18,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input -from flash.core.utilities.imports import _IMAGE_TESTING +from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE from flash.core.utilities.stability import beta from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE @@ -27,7 +27,7 @@ from flash.image.style_transfer.input_transform import StyleTransferInputTransform # Skip doctests if requirements aren't available -if not _IMAGE_TESTING: +if not _TOPIC_IMAGE_AVAILABLE: __doctest_skip__ = ["StyleTransferData", "StyleTransferData.*"] diff --git a/src/flash/image/style_transfer/model.py b/src/flash/image/style_transfer/model.py index 84eb62a04f..eb67e88961 100644 --- a/src/flash/image/style_transfer/model.py +++ b/src/flash/image/style_transfer/model.py @@ -18,12 +18,12 @@ from flash.core.data.io.input import DataKeys from flash.core.model import Task from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _IMAGE_AVAILABLE +from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE from flash.core.utilities.stability import beta from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE from flash.image.style_transfer import STYLE_TRANSFER_BACKBONES -if _IMAGE_AVAILABLE: +if _TOPIC_IMAGE_AVAILABLE: import pystiche.demo from pystiche import enc, loss from pystiche.image import read_image diff --git a/src/flash/pointcloud/detection/datasets.py b/src/flash/pointcloud/detection/datasets.py index 335f699757..cbd5de4772 100644 --- a/src/flash/pointcloud/detection/datasets.py +++ b/src/flash/pointcloud/detection/datasets.py @@ -14,10 +14,10 @@ import os from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE from flash.pointcloud.segmentation.datasets import executor -if _POINTCLOUD_AVAILABLE: +if _TOPIC_POINTCLOUD_AVAILABLE: from open3d.ml.datasets import KITTI _OBJECT_DETECTION_DATASET = FlashRegistry("dataset") diff --git a/src/flash/pointcloud/detection/open3d_ml/app.py b/src/flash/pointcloud/detection/open3d_ml/app.py index c549875597..07f755f1bc 100644 --- a/src/flash/pointcloud/detection/open3d_ml/app.py +++ b/src/flash/pointcloud/detection/open3d_ml/app.py @@ -17,9 +17,9 @@ import flash from flash.core.data.data_module import DataModule from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE -if _POINTCLOUD_AVAILABLE: +if _TOPIC_POINTCLOUD_AVAILABLE: from open3d._ml3d.vis.visualizer import LabelLUT, Visualizer from open3d.visualization import gui diff --git a/src/flash/pointcloud/detection/open3d_ml/backbones.py b/src/flash/pointcloud/detection/open3d_ml/backbones.py index ce38bad11b..0d1dfefd1c 100644 --- a/src/flash/pointcloud/detection/open3d_ml/backbones.py +++ b/src/flash/pointcloud/detection/open3d_ml/backbones.py @@ -19,13 +19,13 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE from flash.core.utilities.providers import _OPEN3D_ML from flash.core.utilities.url_error import catch_url_error ROOT_URL = "https://storage.googleapis.com/open3d-releases/model-zoo/" -if _POINTCLOUD_AVAILABLE: +if _TOPIC_POINTCLOUD_AVAILABLE: import open3d import open3d.ml as _ml3d from open3d._ml3d.torch.dataloaders.concat_batcher import ConcatBatcher, ObjectDetectBatch @@ -50,7 +50,7 @@ def __len__(self): def register_open_3d_ml(register: FlashRegistry): - if _POINTCLOUD_AVAILABLE: + if _TOPIC_POINTCLOUD_AVAILABLE: CONFIG_PATH = os.path.join(os.path.dirname(open3d.__file__), "_ml3d/configs") def get_collate_fn(model) -> Callable: diff --git a/src/flash/pointcloud/detection/open3d_ml/input.py b/src/flash/pointcloud/detection/open3d_ml/input.py index 397d0da49d..30d9d317ff 100644 --- a/src/flash/pointcloud/detection/open3d_ml/input.py +++ b/src/flash/pointcloud/detection/open3d_ml/input.py @@ -18,9 +18,9 @@ import yaml from flash.core.data.io.input import BaseDataFormat, Input -from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE -if _POINTCLOUD_AVAILABLE: +if _TOPIC_POINTCLOUD_AVAILABLE: from open3d._ml3d.datasets.kitti import KITTI, DataProcessing diff --git a/src/flash/pointcloud/segmentation/datasets.py b/src/flash/pointcloud/segmentation/datasets.py index ff792282a4..af0fd28539 100644 --- a/src/flash/pointcloud/segmentation/datasets.py +++ b/src/flash/pointcloud/segmentation/datasets.py @@ -14,9 +14,9 @@ import os from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE -if _POINTCLOUD_AVAILABLE: +if _TOPIC_POINTCLOUD_AVAILABLE: from open3d.ml.datasets import Lyft, SemanticKITTI _SEGMENTATION_DATASET = FlashRegistry("dataset") diff --git a/src/flash/pointcloud/segmentation/model.py b/src/flash/pointcloud/segmentation/model.py index 1cbb2c0752..79363267a3 100644 --- a/src/flash/pointcloud/segmentation/model.py +++ b/src/flash/pointcloud/segmentation/model.py @@ -24,12 +24,12 @@ from flash.core.data.io.input_transform import InputTransform from flash.core.data.utilities.collate import wrap_collate from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE, _TM_GREATER_EQUAL_0_7_0, _TM_GREATER_EQUAL_0_10_0 +from flash.core.utilities.imports import _TM_GREATER_EQUAL_0_7_0, _TM_GREATER_EQUAL_0_10_0, _TOPIC_POINTCLOUD_AVAILABLE from flash.core.utilities.stability import beta from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.pointcloud.segmentation.backbones import POINTCLOUD_SEGMENTATION_BACKBONES -if _POINTCLOUD_AVAILABLE: +if _TOPIC_POINTCLOUD_AVAILABLE: from open3d._ml3d.torch.modules.losses.semseg_loss import filter_valid_label from open3d.ml.torch.dataloaders import TorchDataloader diff --git a/src/flash/pointcloud/segmentation/open3d_ml/app.py b/src/flash/pointcloud/segmentation/open3d_ml/app.py index 929bc93121..7d3eab3f23 100644 --- a/src/flash/pointcloud/segmentation/open3d_ml/app.py +++ b/src/flash/pointcloud/segmentation/open3d_ml/app.py @@ -15,9 +15,9 @@ from flash.core.data.data_module import DataModule from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE -if _POINTCLOUD_AVAILABLE: +if _TOPIC_POINTCLOUD_AVAILABLE: from open3d._ml3d.torch.dataloaders import TorchDataloader from open3d._ml3d.vis.visualizer import LabelLUT from open3d._ml3d.vis.visualizer import Visualizer as Open3dVisualizer diff --git a/src/flash/pointcloud/segmentation/open3d_ml/backbones.py b/src/flash/pointcloud/segmentation/open3d_ml/backbones.py index d84cf49e99..960fe4dd2e 100644 --- a/src/flash/pointcloud/segmentation/open3d_ml/backbones.py +++ b/src/flash/pointcloud/segmentation/open3d_ml/backbones.py @@ -18,7 +18,7 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE from flash.core.utilities.providers import _OPEN3D_ML from flash.core.utilities.url_error import catch_url_error @@ -26,7 +26,7 @@ def register_open_3d_ml(register: FlashRegistry): - if _POINTCLOUD_AVAILABLE: + if _TOPIC_POINTCLOUD_AVAILABLE: import open3d import open3d.ml as _ml3d from open3d._ml3d.torch.dataloaders import ConcatBatcher, DefaultBatcher diff --git a/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py b/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py index e7dfc80cfa..65af4e1315 100644 --- a/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py +++ b/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py @@ -18,9 +18,9 @@ import yaml from torch.utils.data import Dataset -from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE +from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE -if _POINTCLOUD_AVAILABLE: +if _TOPIC_POINTCLOUD_AVAILABLE: from open3d._ml3d.datasets.utils import DataProcessing from open3d._ml3d.utils.config import Config diff --git a/src/flash/tabular/classification/data.py b/src/flash/tabular/classification/data.py index 1db88edd4e..464ffe59dc 100644 --- a/src/flash/tabular/classification/data.py +++ b/src/flash/tabular/classification/data.py @@ -16,7 +16,7 @@ from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform from flash.core.data.utilities.classification import TargetFormatter -from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING +from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TOPIC_TABULAR_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.tabular.classification.input import ( TabularClassificationCSVInput, @@ -32,7 +32,7 @@ DataFrame = object # Skip doctests if requirements aren't available -if not _TABULAR_TESTING: +if not _TOPIC_TABULAR_AVAILABLE: __doctest_skip__ = ["TabularClassificationData", "TabularClassificationData.*"] diff --git a/src/flash/tabular/classification/model.py b/src/flash/tabular/classification/model.py index 1a0e6674e0..78c64fc970 100644 --- a/src/flash/tabular/classification/model.py +++ b/src/flash/tabular/classification/model.py @@ -23,12 +23,12 @@ from flash.core.integrations.pytorch_tabular.backbones import PYTORCH_TABULAR_BACKBONES from flash.core.registry import FlashRegistry from flash.core.serve import Composition -from flash.core.utilities.imports import _TABULAR_TESTING, requires +from flash.core.utilities.imports import _TOPIC_TABULAR_AVAILABLE, requires from flash.core.utilities.types import INPUT_TRANSFORM_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.tabular.input import TabularDeserializer # Skip doctests if requirements aren't available -if not _TABULAR_TESTING: +if not _TOPIC_TABULAR_AVAILABLE: __doctest_skip__ = ["TabularClassifier", "TabularClassifier.*"] diff --git a/src/flash/tabular/forecasting/data.py b/src/flash/tabular/forecasting/data.py index abf4094ad0..36e1fc2f01 100644 --- a/src/flash/tabular/forecasting/data.py +++ b/src/flash/tabular/forecasting/data.py @@ -19,7 +19,7 @@ from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform -from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING +from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TOPIC_TABULAR_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.tabular.forecasting.input import TabularForecastingDataFrameInput @@ -30,7 +30,7 @@ # Skip doctests if requirements aren't available -if not _TABULAR_TESTING: +if not _TOPIC_TABULAR_AVAILABLE: __doctest_skip__ = ["TabularForecastingData", "TabularForecastingData.*"] diff --git a/src/flash/tabular/regression/data.py b/src/flash/tabular/regression/data.py index d9c714ab08..1676f6aef6 100644 --- a/src/flash/tabular/regression/data.py +++ b/src/flash/tabular/regression/data.py @@ -15,7 +15,7 @@ from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform -from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING +from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TOPIC_TABULAR_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.tabular.data import TabularData from flash.tabular.regression.input import ( @@ -31,7 +31,7 @@ DataFrame = object # Skip doctests if requirements aren't available -if not _TABULAR_TESTING: +if not _TOPIC_TABULAR_AVAILABLE: __doctest_skip__ = ["TabularRegressionData", "TabularRegressionData.*"] diff --git a/src/flash/tabular/regression/model.py b/src/flash/tabular/regression/model.py index 0840c5cab1..531cacd5d8 100644 --- a/src/flash/tabular/regression/model.py +++ b/src/flash/tabular/regression/model.py @@ -23,12 +23,12 @@ from flash.core.registry import FlashRegistry from flash.core.regression import RegressionAdapterTask from flash.core.serve import Composition -from flash.core.utilities.imports import _TABULAR_TESTING, requires +from flash.core.utilities.imports import _TOPIC_TABULAR_AVAILABLE, requires from flash.core.utilities.types import INPUT_TRANSFORM_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.tabular.input import TabularDeserializer # Skip doctests if requirements aren't available -if not _TABULAR_TESTING: +if not _TOPIC_TABULAR_AVAILABLE: __doctest_skip__ = ["TabularRegressor", "TabularRegressor.*"] diff --git a/src/flash/text/classification/data.py b/src/flash/text/classification/data.py index 38fb909f44..05e2a5c42b 100644 --- a/src/flash/text/classification/data.py +++ b/src/flash/text/classification/data.py @@ -21,7 +21,7 @@ from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.paths import PATH_TYPE from flash.core.integrations.labelstudio.input import LabelStudioTextClassificationInput, _parse_labelstudio_arguments -from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.text.classification.input import ( TextClassificationCSVInput, @@ -32,13 +32,13 @@ TextClassificationParquetInput, ) -if _TEXT_AVAILABLE: +if _TOPIC_TEXT_AVAILABLE: from datasets import Dataset else: Dataset = object # Skip doctests if requirements aren't available -if not _TEXT_TESTING: +if not _TOPIC_TEXT_AVAILABLE: __doctest_skip__ = ["TextClassificationData", "TextClassificationData.*"] diff --git a/src/flash/text/classification/input.py b/src/flash/text/classification/input.py index 73f0a9f0c8..54aae52532 100644 --- a/src/flash/text/classification/input.py +++ b/src/flash/text/classification/input.py @@ -21,9 +21,9 @@ from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter from flash.core.data.utilities.loading import load_data_frame from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.utilities.imports import _TEXT_AVAILABLE, requires +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE, requires -if _TEXT_AVAILABLE: +if _TOPIC_TEXT_AVAILABLE: from datasets import Dataset, load_dataset else: Dataset = object diff --git a/src/flash/text/embedding/backbones.py b/src/flash/text/embedding/backbones.py index c421e0179e..e557e115a8 100644 --- a/src/flash/text/embedding/backbones.py +++ b/src/flash/text/embedding/backbones.py @@ -1,8 +1,8 @@ from flash.core.registry import ExternalRegistry, FlashRegistry -from flash.core.utilities.imports import _TEXT_AVAILABLE +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE from flash.core.utilities.providers import _HUGGINGFACE -if _TEXT_AVAILABLE: +if _TOPIC_TEXT_AVAILABLE: from transformers import AutoModel HUGGINGFACE_BACKBONES = ExternalRegistry( diff --git a/src/flash/text/embedding/model.py b/src/flash/text/embedding/model.py index a170265e28..9c73b975b5 100644 --- a/src/flash/text/embedding/model.py +++ b/src/flash/text/embedding/model.py @@ -22,13 +22,13 @@ from flash.core.model import Task from flash.core.registry import FlashRegistry, print_provider_info -from flash.core.utilities.imports import _TEXT_AVAILABLE +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE from flash.core.utilities.providers import _SENTENCE_TRANSFORMERS from flash.text.classification.collate import TextClassificationCollate from flash.text.embedding.backbones import HUGGINGFACE_BACKBONES from flash.text.ort_callback import ORTCallback -if _TEXT_AVAILABLE: +if _TOPIC_TEXT_AVAILABLE: from sentence_transformers.models import Pooling Pooling = print_provider_info("Pooling", _SENTENCE_TRANSFORMERS, Pooling) diff --git a/src/flash/text/question_answering/data.py b/src/flash/text/question_answering/data.py index 06b7380166..f87a41faf8 100644 --- a/src/flash/text/question_answering/data.py +++ b/src/flash/text/question_answering/data.py @@ -17,7 +17,7 @@ from flash.core.data.io.input import Input from flash.core.data.io.input_transform import InputTransform from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.utilities.imports import _TEXT_AVAILABLE +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.text.question_answering.input import ( @@ -28,7 +28,7 @@ ) # Skip doctests if requirements aren't available -if not _TEXT_AVAILABLE: +if not _TOPIC_TEXT_AVAILABLE: __doctest_skip__ = ["QuestionAnsweringData", "QuestionAnsweringData.*"] diff --git a/src/flash/text/question_answering/input.py b/src/flash/text/question_answering/input.py index 60302fe505..8517ea9c14 100644 --- a/src/flash/text/question_answering/input.py +++ b/src/flash/text/question_answering/input.py @@ -23,9 +23,9 @@ from flash.core.data.io.input import Input from flash.core.data.utilities.loading import load_data_frame from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.utilities.imports import _TEXT_AVAILABLE, requires +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE, requires -if _TEXT_AVAILABLE: +if _TOPIC_TEXT_AVAILABLE: from datasets import Dataset, load_dataset else: Dataset = object diff --git a/src/flash/text/question_answering/model.py b/src/flash/text/question_answering/model.py index 9405451375..478c2778a9 100644 --- a/src/flash/text/question_answering/model.py +++ b/src/flash/text/question_answering/model.py @@ -32,14 +32,14 @@ from flash.core.data.io.input import DataKeys from flash.core.model import Task from flash.core.registry import ExternalRegistry, FlashRegistry -from flash.core.utilities.imports import _TEXT_AVAILABLE, _TM_GREATER_EQUAL_0_7_0 +from flash.core.utilities.imports import _TM_GREATER_EQUAL_0_7_0, _TOPIC_TEXT_AVAILABLE from flash.core.utilities.providers import _HUGGINGFACE from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.text.ort_callback import ORTCallback from flash.text.question_answering.collate import TextQuestionAnsweringCollate from flash.text.question_answering.output_transform import QuestionAnsweringOutputTransform -if _TEXT_AVAILABLE: +if _TOPIC_TEXT_AVAILABLE: from transformers import AutoModelForQuestionAnswering HUGGINGFACE_BACKBONES = ExternalRegistry( diff --git a/src/flash/text/seq2seq/core/input.py b/src/flash/text/seq2seq/core/input.py index 01421fa8c1..857aa75bdb 100644 --- a/src/flash/text/seq2seq/core/input.py +++ b/src/flash/text/seq2seq/core/input.py @@ -17,9 +17,9 @@ from flash.core.data.io.input import DataKeys, Input from flash.core.data.utilities.loading import load_data_frame from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.utilities.imports import _TEXT_AVAILABLE, requires +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE, requires -if _TEXT_AVAILABLE: +if _TOPIC_TEXT_AVAILABLE: from datasets import Dataset, load_dataset else: Dataset = object diff --git a/src/flash/text/seq2seq/core/model.py b/src/flash/text/seq2seq/core/model.py index 75fea8ea23..8a00672c78 100644 --- a/src/flash/text/seq2seq/core/model.py +++ b/src/flash/text/seq2seq/core/model.py @@ -28,7 +28,7 @@ from flash.core.model import Task from flash.core.registry import ExternalRegistry, FlashRegistry from flash.core.serve import Composition -from flash.core.utilities.imports import _TEXT_AVAILABLE, requires +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE, requires from flash.core.utilities.providers import _HUGGINGFACE from flash.core.utilities.types import ( INPUT_TRANSFORM_TYPE, @@ -41,7 +41,7 @@ from flash.text.ort_callback import ORTCallback from flash.text.seq2seq.core.collate import TextSeq2SeqCollate -if _TEXT_AVAILABLE: +if _TOPIC_TEXT_AVAILABLE: from transformers import AutoModelForSeq2SeqLM HUGGINGFACE_BACKBONES = ExternalRegistry( diff --git a/src/flash/text/seq2seq/summarization/data.py b/src/flash/text/seq2seq/summarization/data.py index cb42e03f43..8f7c2899d7 100644 --- a/src/flash/text/seq2seq/summarization/data.py +++ b/src/flash/text/seq2seq/summarization/data.py @@ -17,18 +17,18 @@ from flash.core.data.io.input import Input from flash.core.data.io.input_transform import InputTransform from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.text.seq2seq.core.input import Seq2SeqCSVInput, Seq2SeqInputBase, Seq2SeqJSONInput, Seq2SeqListInput -if _TEXT_AVAILABLE: +if _TOPIC_TEXT_AVAILABLE: from datasets import Dataset else: Dataset = object # Skip doctests if requirements aren't available -if not _TEXT_TESTING: +if not _TOPIC_TEXT_AVAILABLE: __doctest_skip__ = ["SummarizationData", "SummarizationData.*"] diff --git a/src/flash/text/seq2seq/translation/data.py b/src/flash/text/seq2seq/translation/data.py index 6aa2a4a76e..e5c8af1c0a 100644 --- a/src/flash/text/seq2seq/translation/data.py +++ b/src/flash/text/seq2seq/translation/data.py @@ -17,18 +17,18 @@ from flash.core.data.io.input import Input from flash.core.data.io.input_transform import InputTransform from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.text.seq2seq.core.input import Seq2SeqCSVInput, Seq2SeqInputBase, Seq2SeqJSONInput, Seq2SeqListInput -if _TEXT_AVAILABLE: +if _TOPIC_TEXT_AVAILABLE: from datasets import Dataset else: Dataset = object # Skip doctests if requirements aren't available -if not _TEXT_TESTING: +if not _TOPIC_TEXT_AVAILABLE: __doctest_skip__ = ["TranslationData", "TranslationData.*"] diff --git a/src/flash/video/classification/data.py b/src/flash/video/classification/data.py index bc24b85988..15065c643b 100644 --- a/src/flash/video/classification/data.py +++ b/src/flash/video/classification/data.py @@ -23,13 +23,7 @@ from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.paths import PATH_TYPE from flash.core.integrations.labelstudio.input import LabelStudioVideoClassificationInput, _parse_labelstudio_arguments -from flash.core.utilities.imports import ( - _FIFTYONE_AVAILABLE, - _PYTORCHVIDEO_AVAILABLE, - _VIDEO_EXTRAS_TESTING, - _VIDEO_TESTING, - requires, -) +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, _TOPIC_VIDEO_AVAILABLE, requires from flash.core.utilities.stages import RunningStage from flash.video.classification.input import ( VideoClassificationCSVInput, @@ -58,7 +52,7 @@ # Skip doctests if requirements aren't available __doctest_skip__ = [] -if not _VIDEO_TESTING: +if not _TOPIC_VIDEO_AVAILABLE: __doctest_skip__ += [ "VideoClassificationData", "VideoClassificationData.from_files", @@ -66,9 +60,8 @@ "VideoClassificationData.from_data_frame", "VideoClassificationData.from_csv", "VideoClassificationData.from_tensors", + "VideoClassificationData.from_fiftyone", ] -if not _VIDEO_EXTRAS_TESTING: - __doctest_skip__ += ["VideoClassificationData.from_fiftyone"] class VideoClassificationData(DataModule): diff --git a/src/flash/video/classification/utils.py b/src/flash/video/classification/utils.py index 5d51ca216e..b585700ba1 100644 --- a/src/flash/video/classification/utils.py +++ b/src/flash/video/classification/utils.py @@ -2,9 +2,9 @@ import torch -from flash.core.utilities.imports import _VIDEO_AVAILABLE +from flash.core.utilities.imports import _TOPIC_VIDEO_AVAILABLE -if _VIDEO_AVAILABLE: +if _TOPIC_VIDEO_AVAILABLE: from pytorchvideo.data.utils import MultiProcessSampler else: MultiProcessSampler = None diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py index 306c9859d9..0212c40215 100644 --- a/tests/audio/classification/test_data.py +++ b/tests/audio/classification/test_data.py @@ -21,7 +21,7 @@ import flash from flash.audio import AudioClassificationData -from flash.core.utilities.imports import _AUDIO_TESTING, _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE +from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _TOPIC_AUDIO_AVAILABLE if _PIL_AVAILABLE: from PIL import Image @@ -49,7 +49,7 @@ def _audio_files(_): return [raw_audio_path, raw_audio_path], 1 -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") @pytest.mark.parametrize("file_generator", [_image_files, _audio_files]) def test_from_filepaths(tmpdir, file_generator): train_images, channels = file_generator(tmpdir) @@ -69,7 +69,7 @@ def test_from_filepaths(tmpdir, file_generator): assert sorted(list(labels.numpy())) == [1, 2] -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") @pytest.mark.parametrize( "data,from_function", [ @@ -112,7 +112,7 @@ def test_from_data(data, from_function): assert list(labels.numpy()) == [2, 5] -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") def test_from_filepaths_numpy(tmpdir): tmpdir = Path(tmpdir) @@ -139,7 +139,7 @@ def test_from_filepaths_numpy(tmpdir): assert sorted(list(labels.numpy())) == [1, 2] -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") def test_from_filepaths_list_image_paths(tmpdir): tmpdir = Path(tmpdir) @@ -185,7 +185,7 @@ def test_from_filepaths_list_image_paths(tmpdir): assert list(labels.numpy()) == [2, 5] -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") @pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") def test_from_filepaths_visualise(tmpdir): tmpdir = Path(tmpdir) @@ -220,7 +220,7 @@ def test_from_filepaths_visualise(tmpdir): dm.show_train_batch(["per_sample_transform", "per_batch_transform"]) -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") @pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") def test_from_filepaths_visualise_multilabel(tmpdir): tmpdir = Path(tmpdir) @@ -255,7 +255,7 @@ def test_from_filepaths_visualise_multilabel(tmpdir): dm.show_val_batch("per_batch_transform") -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") def test_from_folders_only_train(tmpdir): seed_everything(42) @@ -278,7 +278,7 @@ def test_from_folders_only_train(tmpdir): assert labels.shape == (1,) -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") def test_from_folders_train_val(tmpdir): seed_everything(42) @@ -319,7 +319,7 @@ def test_from_folders_train_val(tmpdir): assert list(labels.numpy()) == [0, 0] -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") def test_from_filepaths_multilabel(tmpdir): tmpdir = Path(tmpdir) diff --git a/tests/audio/classification/test_model.py b/tests/audio/classification/test_model.py index 0811ea4834..9bd2c83782 100644 --- a/tests/audio/classification/test_model.py +++ b/tests/audio/classification/test_model.py @@ -16,11 +16,11 @@ import pytest from flash.__main__ import main -from flash.core.utilities.imports import _AUDIO_TESTING, _IMAGE_AVAILABLE +from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE, _TOPIC_IMAGE_AVAILABLE -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") def test_cli(): cli_args = ["flash", "audio_classification", "--trainer.fast_dev_run", "True"] with mock.patch("sys.argv", cli_args): diff --git a/tests/audio/speech_recognition/test_data.py b/tests/audio/speech_recognition/test_data.py index dd75e696ac..37ffa918f8 100644 --- a/tests/audio/speech_recognition/test_data.py +++ b/tests/audio/speech_recognition/test_data.py @@ -20,7 +20,7 @@ import flash from flash.audio import SpeechRecognitionData from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _AUDIO_AVAILABLE, _AUDIO_TESTING +from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE path = str(Path(flash.ASSETS_ROOT) / "example.wav") sample = {"file": path, "text": "example input."} @@ -48,7 +48,7 @@ def json_data(tmpdir, n_samples=5): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="speech libraries aren't installed.") def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) dm = SpeechRecognitionData.from_csv("file", "text", train_file=csv_path, batch_size=1, num_workers=0) @@ -58,7 +58,7 @@ def test_from_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="speech libraries aren't installed.") def test_stage_test_and_valid(tmpdir): csv_path = csv_data(tmpdir) dm = SpeechRecognitionData.from_csv( @@ -74,7 +74,7 @@ def test_stage_test_and_valid(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="speech libraries aren't installed.") def test_from_json(tmpdir): json_path = json_data(tmpdir) dm = SpeechRecognitionData.from_json("file", "text", train_file=json_path, batch_size=1, num_workers=0) @@ -83,7 +83,7 @@ def test_from_json(tmpdir): assert DataKeys.TARGET in batch -@pytest.mark.skipif(_AUDIO_AVAILABLE, reason="audio libraries are installed.") +@pytest.mark.skipif(_TOPIC_AUDIO_AVAILABLE, reason="audio libraries are installed.") def test_audio_module_not_found_error(): with pytest.raises(ModuleNotFoundError, match="[audio]"): SpeechRecognitionData.from_json("file", "text", train_file="", batch_size=1, num_workers=0) diff --git a/tests/audio/speech_recognition/test_data_model_integration.py b/tests/audio/speech_recognition/test_data_model_integration.py index 889bd56be8..6cedcaad59 100644 --- a/tests/audio/speech_recognition/test_data_model_integration.py +++ b/tests/audio/speech_recognition/test_data_model_integration.py @@ -20,7 +20,7 @@ import flash from flash import Trainer from flash.audio import SpeechRecognition, SpeechRecognitionData -from flash.core.utilities.imports import _AUDIO_TESTING +from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE TEST_BACKBONE = "patrickvonplaten/wav2vec2_tiny_random_robust" # tiny model for testing @@ -50,7 +50,7 @@ def json_data(tmpdir, n_samples=5): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") def test_classification_csv(tmpdir): csv_path = csv_data(tmpdir) @@ -67,7 +67,7 @@ def test_classification_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") def test_classification_json(tmpdir): json_path = json_data(tmpdir) diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py index 19e53b70f7..1271a0acfa 100644 --- a/tests/audio/speech_recognition/test_model.py +++ b/tests/audio/speech_recognition/test_model.py @@ -21,7 +21,7 @@ from flash.audio import SpeechRecognition from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _AUDIO_AVAILABLE, _AUDIO_TESTING, _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE, _TOPIC_SERVE_AVAILABLE from tests.helpers.task_tester import TaskTester TEST_BACKBONE = "patrickvonplaten/wav2vec2_tiny_random_robust" # tiny model for testing @@ -31,8 +31,8 @@ class TestSpeechRecognition(TaskTester): task = SpeechRecognition task_kwargs = dict(backbone=TEST_BACKBONE) cli_command = "speech_recognition" - is_testing = _AUDIO_TESTING - is_available = _AUDIO_AVAILABLE + is_testing = _TOPIC_AUDIO_AVAILABLE + is_available = _TOPIC_AUDIO_AVAILABLE scriptable = False @@ -61,13 +61,13 @@ def example_test_sample(self): return self.example_train_sample -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") def test_modules_to_freeze(): model = SpeechRecognition(backbone=TEST_BACKBONE) assert model.modules_to_freeze() is model.model.wav2vec2 -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") @mock.patch("flash._IS_TESTING", True) def test_serve(): model = SpeechRecognition(backbone=TEST_BACKBONE) diff --git a/tests/conftest.py b/tests/conftest.py index 894c7b55b8..8a6b303cc6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ from pytest_mock import MockerFixture from flash.core.serve.decorators import uuid4 # noqa (used in mocker.patch) -from flash.core.utilities.imports import _SERVE_TESTING, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE, _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: import torchvision @@ -60,7 +60,7 @@ def global_datadir(tmp_path_factory, original_global_datadir): return prep_global_datadir(tmp_path_factory, original_global_datadir) -if _SERVE_TESTING: +if _TOPIC_SERVE_AVAILABLE: @pytest.fixture(scope="session") def squeezenet1_1_model(): diff --git a/tests/core/data/io/test_input.py b/tests/core/data/io/test_input.py index 229a35424a..fc668ef64d 100644 --- a/tests/core/data/io/test_input.py +++ b/tests/core/data/io/test_input.py @@ -14,11 +14,11 @@ import pytest from flash.core.data.io.input import Input, IterableInput, ServeInput -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE from flash.core.utilities.stages import RunningStage -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_input_validation(): with pytest.raises(RuntimeError, match="Use `IterableInput` instead."): @@ -39,7 +39,7 @@ def __init__(self, *args, **kwargs): ValidInput(RunningStage.TRAINING) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_iterable_input_validation(): with pytest.raises(RuntimeError, match="Use `Input` instead."): @@ -60,7 +60,7 @@ def __init__(self, *args, **kwargs): ValidIterableInput(RunningStage.TRAINING) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_serve_input(): server_input = ServeInput() assert server_input.serving diff --git a/tests/core/data/io/test_output.py b/tests/core/data/io/test_output.py index 15051c7bfb..1f85c7eb42 100644 --- a/tests/core/data/io/test_output.py +++ b/tests/core/data/io/test_output.py @@ -16,10 +16,10 @@ import pytest from flash.core.data.io.output import Output -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_output(): """Tests basic ``Output`` methods.""" my_output = Output() diff --git a/tests/core/data/io/test_output_transform.py b/tests/core/data/io/test_output_transform.py index 7c69a9ad91..d84907e674 100644 --- a/tests/core/data/io/test_output_transform.py +++ b/tests/core/data/io/test_output_transform.py @@ -15,10 +15,10 @@ import torch from flash.core.data.io.output_transform import OutputTransform -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_output_transform(): class CustomOutputTransform(OutputTransform): @staticmethod diff --git a/tests/core/data/test_base_viz.py b/tests/core/data/test_base_viz.py index 95b33c3eae..10bf8082c8 100644 --- a/tests/core/data/test_base_viz.py +++ b/tests/core/data/test_base_viz.py @@ -23,7 +23,7 @@ from flash.core.data.base_viz import BaseVisualization from flash.core.data.io.input import DataKeys from flash.core.data.utils import _CALLBACK_FUNCS -from flash.core.utilities.imports import _IMAGE_TESTING, _PIL_AVAILABLE +from flash.core.utilities.imports import _PIL_AVAILABLE, _TOPIC_IMAGE_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.image import ImageClassificationData @@ -83,7 +83,7 @@ def check_reset(self): self.per_batch_transform_called = False -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") class TestBaseViz: def test_base_viz(self, tmpdir): seed_everything(42) diff --git a/tests/core/data/test_batch.py b/tests/core/data/test_batch.py index 9b91ff7e0b..f9efad9091 100644 --- a/tests/core/data/test_batch.py +++ b/tests/core/data/test_batch.py @@ -17,7 +17,7 @@ import torch from flash.core.data.batch import default_uncollate -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE Case = namedtuple("Case", ["collated_batch", "uncollated_batch"]) @@ -47,7 +47,7 @@ ] -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("case", cases) def test_default_uncollate(case): assert default_uncollate(case.collated_batch) == case.uncollated_batch @@ -62,7 +62,7 @@ def test_default_uncollate(case): ] -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("error_case", error_cases) def test_default_uncollate_raises(error_case): with pytest.raises(ValueError, match=error_case.match): diff --git a/tests/core/data/test_callback.py b/tests/core/data/test_callback.py index 62fd837d70..73d8803436 100644 --- a/tests/core/data/test_callback.py +++ b/tests/core/data/test_callback.py @@ -21,11 +21,11 @@ from flash.core.data.io.input_transform import InputTransform from flash.core.model import Task from flash.core.trainer import Trainer -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE from flash.core.utilities.stages import RunningStage -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @mock.patch("pickle.dumps") # need to mock pickle or we get pickle error @mock.patch("torch.save") # need to mock torch.save, or we get pickle error def test_flash_callback(_, __, tmpdir): diff --git a/tests/core/data/test_callbacks.py b/tests/core/data/test_callbacks.py index b98b5a7cbb..f31a3ce82d 100644 --- a/tests/core/data/test_callbacks.py +++ b/tests/core/data/test_callbacks.py @@ -21,11 +21,11 @@ from flash.core.data.data_module import DataModule from flash.core.data.io.input import Input from flash.core.data.io.input_transform import InputTransform -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE from flash.core.utilities.stages import RunningStage -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_base_data_fetcher(tmpdir): class CheckData(BaseDataFetcher): def check(self): diff --git a/tests/core/data/test_data_module.py b/tests/core/data/test_data_module.py index 84c194fdbe..46f812bceb 100644 --- a/tests/core/data/test_data_module.py +++ b/tests/core/data/test_data_module.py @@ -25,7 +25,7 @@ from flash.core.data.data_module import DataModule, DatasetInput from flash.core.data.io.input import Input from flash.core.data.io.input_transform import InputTransform -from flash.core.utilities.imports import _CORE_TESTING, _IMAGE_TESTING, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE, _TOPIC_IMAGE_AVAILABLE, _TORCHVISION_AVAILABLE from flash.core.utilities.stages import RunningStage from tests.helpers.boring_model import BoringModel @@ -33,7 +33,7 @@ import torchvision.transforms as T -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_data_module(): seed_everything(42) @@ -311,7 +311,7 @@ def test_step(self, batch, batch_idx): assert batch[0].shape == torch.Size([2, 1]) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_transformations(tmpdir): transform = TestInputTransform() datamodule = DataModule( @@ -364,7 +364,7 @@ def test_transformations(tmpdir): assert datamodule.input_transform.test_per_sample_transform_called -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_datapipeline_transformations_overridden_by_task(): # define input transforms class ImageInput(Input): @@ -425,7 +425,7 @@ def validation_step(self, batch, batch_idx): trainer.fit(model, datamodule=datamodule) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("sampler, callable", [(mock.MagicMock(), True), (mock.NonCallableMock(), False)]) @mock.patch("flash.core.data.data_module.DataLoader") def test_dataloaders_with_sampler(mock_dataloader, sampler, callable): @@ -453,7 +453,7 @@ def test_dataloaders_with_sampler(mock_dataloader, sampler, callable): assert "sampler" not in kwargs -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_val_split(): datamodule = DataModule( Input(RunningStage.TRAINING, [1] * 100), diff --git a/tests/core/data/test_input_transform.py b/tests/core/data/test_input_transform.py index bb71b07914..7394f62ece 100644 --- a/tests/core/data/test_input_transform.py +++ b/tests/core/data/test_input_transform.py @@ -16,11 +16,11 @@ import pytest from flash.core.data.io.input_transform import InputTransform -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE from flash.core.utilities.stages import RunningStage -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_input_transform(): def fn(x): return x + 1 @@ -128,7 +128,7 @@ def predict_per_batch_transform_on_device(self, *_, **__): return self.custom_transform -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_check_transforms(): input_transform = CustomInputTransform diff --git a/tests/core/data/test_properties.py b/tests/core/data/test_properties.py index 03a79e4dc2..82d781d887 100644 --- a/tests/core/data/test_properties.py +++ b/tests/core/data/test_properties.py @@ -14,11 +14,11 @@ import pytest from flash.core.data.properties import Properties -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE from flash.core.utilities.stages import RunningStage -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( "running_stage", [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING] ) diff --git a/tests/core/data/test_splits.py b/tests/core/data/test_splits.py index d48ff72d4a..ea670135c3 100644 --- a/tests/core/data/test_splits.py +++ b/tests/core/data/test_splits.py @@ -18,10 +18,10 @@ from flash.core.data.data_module import DataModule from flash.core.data.splits import SplitDataset -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_split_dataset(): train_ds, val_ds = DataModule._split_train_val(range(100), val_split=0.1) assert len(train_ds) == 90 @@ -47,7 +47,7 @@ def __len__(self): assert not split_dataset.dataset.is_passed_down -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_misconfiguration(): with pytest.raises(ValueError, match="[0, 99]"): SplitDataset(range(100), indices=[100]) @@ -65,7 +65,7 @@ def test_misconfiguration(): SplitDataset(list(range(100)), indices="not a list") -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_deepcopy(): """Tests that deepcopy works with the ``SplitDataset``.""" dataset = list(range(100)) diff --git a/tests/core/data/test_transforms.py b/tests/core/data/test_transforms.py index e99025f615..fbc678acf7 100644 --- a/tests/core/data/test_transforms.py +++ b/tests/core/data/test_transforms.py @@ -18,11 +18,11 @@ from flash.core.data.io.input import DataKeys from flash.core.data.transforms import ApplyToKeys -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE class TestApplyToKeys: - @pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") + @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( "sample, keys, expected", [ @@ -47,7 +47,7 @@ def test_forward(self, sample, keys, expected): else: transform.assert_not_called() - @pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") + @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( "transform, expected", [ diff --git a/tests/core/data/utilities/test_classification.py b/tests/core/data/utilities/test_classification.py index 709a9c510f..0bfb809e55 100644 --- a/tests/core/data/utilities/test_classification.py +++ b/tests/core/data/utilities/test_classification.py @@ -30,7 +30,7 @@ SpaceDelimitedTargetFormatter, get_target_formatter, ) -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE Case = namedtuple("Case", ["target", "formatted_target", "target_formatter_type", "labels", "num_classes"]) @@ -139,7 +139,7 @@ ] -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("case", cases) def test_case(case): formatter = get_target_formatter(case.target) @@ -150,7 +150,7 @@ def test_case(case): assert [formatter(t) for t in case.target] == case.formatted_target -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("case", cases) def test_speed(case): repeats = int(1e5 / len(case.target)) # Approx. a hundred thousand targets @@ -166,10 +166,10 @@ def test_speed(case): formatter = get_target_formatter(targets) end = time.perf_counter() - assert (end - start) / len(targets) < 1e-4 # 0.1ms per target + assert (end - start) / len(targets) < 5e-4 # 0.1ms per target start = time.perf_counter() _ = [formatter(t) for t in targets] end = time.perf_counter() - assert (end - start) / len(targets) < 1e-4 # 0.1ms per target + assert (end - start) / len(targets) < 5e-4 # 0.1ms per target diff --git a/tests/core/data/utilities/test_loading.py b/tests/core/data/utilities/test_loading.py index 15bc0ec6d7..f8a4c8f48f 100644 --- a/tests/core/data/utilities/test_loading.py +++ b/tests/core/data/utilities/test_loading.py @@ -28,15 +28,14 @@ load_spectrogram, ) from flash.core.utilities.imports import ( - _AUDIO_AVAILABLE, - _AUDIO_TESTING, - _IMAGE_TESTING, _PANDAS_AVAILABLE, - _TABULAR_TESTING, + _TOPIC_AUDIO_AVAILABLE, + _TOPIC_IMAGE_AVAILABLE, + _TOPIC_TABULAR_AVAILABLE, Image, ) -if _AUDIO_AVAILABLE: +if _TOPIC_AUDIO_AVAILABLE: import soundfile as sf if _PANDAS_AVAILABLE: @@ -81,7 +80,7 @@ def write_tsv(file_path): ).to_csv(file_path, sep="\t") -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.parametrize( "extension,write", [(extension, write_image) for extension in IMG_EXTENSIONS] @@ -99,7 +98,7 @@ def test_load_image(tmpdir, extension, write): assert image.mode == "RGB" -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") @pytest.mark.parametrize( "extension,write", [(extension, write_image) for extension in IMG_EXTENSIONS] @@ -116,7 +115,7 @@ def test_load_spectrogram(tmpdir, extension, write): assert spectrogram.dtype == np.dtype("float32") -@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") @pytest.mark.parametrize("extension,write", [(extension, write_audio) for extension in AUDIO_EXTENSIONS]) def test_load_audio(tmpdir, extension, write): file_path = os.path.join(tmpdir, f"test{extension}") @@ -128,7 +127,7 @@ def test_load_audio(tmpdir, extension, write): assert audio.dtype == np.dtype("float32") -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") @pytest.mark.parametrize( "extension,write", [(extension, write_csv) for extension in CSV_EXTENSIONS] + [(extension, write_tsv) for extension in TSV_EXTENSIONS], @@ -149,26 +148,26 @@ def test_load_data_frame(tmpdir, extension, write): "https://pl-flash-data.s3.amazonaws.com/images/ant_1.jpg", load_image, Image.Image, - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed."), + marks=pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed."), ), # it shouldn't try to expand glob patterns in URLs pytest.param( "https://pl-flash-data.s3.amazonaws.com/images/ant_1 [test].jpg", load_image, Image.Image, - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed."), + marks=pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed."), ), pytest.param( "https://pl-flash-data.s3.amazonaws.com/images/ant_1.jpg", load_spectrogram, np.ndarray, - marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed."), + marks=pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed."), ), pytest.param( "https://pl-flash-data.s3.amazonaws.com/titanic.csv", load_data_frame, DataFrame, - marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed."), + marks=pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed."), ), ], ) diff --git a/tests/core/integrations/labelstudio/test_labelstudio.py b/tests/core/integrations/labelstudio/test_labelstudio.py index cff87fed25..84d1a19576 100644 --- a/tests/core/integrations/labelstudio/test_labelstudio.py +++ b/tests/core/integrations/labelstudio/test_labelstudio.py @@ -8,14 +8,19 @@ _load_json_data, ) from flash.core.integrations.labelstudio.visualizer import launch_app -from flash.core.utilities.imports import _CORE_TESTING, _IMAGE_TESTING, _TEXT_TESTING, _VIDEO_TESTING +from flash.core.utilities.imports import ( + _TOPIC_CORE_AVAILABLE, + _TOPIC_IMAGE_AVAILABLE, + _TOPIC_TEXT_AVAILABLE, + _TOPIC_VIDEO_AVAILABLE, +) from flash.core.utilities.stages import RunningStage from flash.image.classification.data import ImageClassificationData from flash.text.classification.data import TextClassificationData from flash.video.classification.data import LabelStudioVideoClassificationInput, VideoClassificationData -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_utility_load(): """Test for label studio json loader.""" data = [ @@ -139,7 +144,7 @@ def test_utility_load(): assert len(ds_multi[0]) == 5 -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_input_labelstudio(): """Test creation of LabelStudioInput.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/data.zip") @@ -159,7 +164,7 @@ def test_input_labelstudio(): assert val_sample -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_input_labelstudio_image(): """Test creation of LabelStudioImageClassificationInput from images.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/data_nofile.zip") @@ -180,7 +185,7 @@ def test_input_labelstudio_image(): assert val_sample -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_datamodule_labelstudio_image(): """Test creation of LabelStudioImageClassificationInput and Datamodule from images.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/data.zip") @@ -196,7 +201,7 @@ def test_datamodule_labelstudio_image(): assert datamodule -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_label_studio_predictions_visualization(): """Test creation of LabelStudioImageClassificationInput and Datamodule from images.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/data.zip") @@ -222,7 +227,7 @@ def test_label_studio_predictions_visualization(): assert tasks_predictions_json -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_input_labelstudio_text(): """Test creation of LabelStudioTextClassificationInput and Datamodule from text.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/text_data.zip", "./data/") @@ -245,7 +250,7 @@ def test_input_labelstudio_text(): assert len(test) == 0 -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_datamodule_labelstudio_text(): """Test creation of LabelStudioTextClassificationInput and Datamodule from text.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/text_data.zip", "./data/") @@ -257,7 +262,7 @@ def test_datamodule_labelstudio_text(): assert datamodule -@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") def test_input_labelstudio_video(): """Test creation of LabelStudioVideoClassificationInput from video.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/video_data.zip") @@ -271,7 +276,7 @@ def test_input_labelstudio_video(): assert sample -@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") def test_datamodule_labelstudio_video(): """Test creation of Datamodule from video.""" download_data("https://label-studio-testdata.s3.us-east-2.amazonaws.com/lightning-flash/video_data.zip") diff --git a/tests/core/optimizers/test_lr_scheduler.py b/tests/core/optimizers/test_lr_scheduler.py index c8406afe12..b18f9b4a1f 100644 --- a/tests/core/optimizers/test_lr_scheduler.py +++ b/tests/core/optimizers/test_lr_scheduler.py @@ -18,10 +18,10 @@ from torch.optim import Adam from flash.core.optimizers import LinearWarmupCosineAnnealingLR -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( "lr, warmup_epochs, max_epochs, warmup_start_lr, eta_min", [ diff --git a/tests/core/optimizers/test_optimizers.py b/tests/core/optimizers/test_optimizers.py index 9b276f28c5..bdfff256ee 100644 --- a/tests/core/optimizers/test_optimizers.py +++ b/tests/core/optimizers/test_optimizers.py @@ -16,10 +16,10 @@ from torch import nn from flash.core.optimizers import LAMB, LARS, LinearWarmupCosineAnnealingLR -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( "optim_fn, lr, kwargs", [ @@ -43,7 +43,7 @@ def test_optim_call(tmpdir, optim_fn, lr, kwargs): optimizer.step() -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("optim_fn, lr", [(LARS, 0.1), (LAMB, 1e-3)]) def test_optim_with_scheduler(tmpdir, optim_fn, lr): max_epochs = 10 diff --git a/tests/core/serve/test_compat/test_cached_property.py b/tests/core/serve/test_compat/test_cached_property.py index ffa50b360d..9a856b2b61 100644 --- a/tests/core/serve/test_compat/test_cached_property.py +++ b/tests/core/serve/test_compat/test_cached_property.py @@ -14,7 +14,7 @@ # Package Implementation from flash.core.serve._compat.cached_property import cached_property -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE class CachedCostItem: @@ -78,7 +78,7 @@ def cost(self): # noinspection PyStatementEffect -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") class TestCachedProperty: @staticmethod def test_cached(): @@ -210,7 +210,7 @@ def test_doc(): assert CachedCostItem.cost.__doc__ == "The cost of the item." -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") @pytest.mark.skipif(sys.version_info < (3, 8), reason="Validate, that python 3.8 uses standard implementation") class TestPy38Plus: @staticmethod diff --git a/tests/core/serve/test_components.py b/tests/core/serve/test_components.py index ab971a4724..c95652e37f 100644 --- a/tests/core/serve/test_components.py +++ b/tests/core/serve/test_components.py @@ -2,11 +2,11 @@ import torch from flash.core.serve.types import Label -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE from tests.core.serve.models import ClassificationInferenceComposable, LightningSqueezenet -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_model_compute_call_method(lightning_squeezenet1_1_obj): comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) img = torch.arange(195075).reshape((1, 255, 255, 3)) @@ -15,7 +15,7 @@ def test_model_compute_call_method(lightning_squeezenet1_1_obj): assert out_res.item() == 753 -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_model_compute_dependencies(lightning_squeezenet1_1_obj): comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) @@ -33,7 +33,7 @@ def test_model_compute_dependencies(lightning_squeezenet1_1_obj): assert list(comp2._flashserve_meta_.connections) == [] -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_inverse_model_compute_component_dependencies(lightning_squeezenet1_1_obj): comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) @@ -52,7 +52,7 @@ def test_inverse_model_compute_component_dependencies(lightning_squeezenet1_1_ob assert list(comp1._flashserve_meta_.connections) == [] -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_two_component_invalid_dependencies_fail(lightning_squeezenet1_1_obj): comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) @@ -86,7 +86,7 @@ def __init__(self): comp1.inputs["tag"] >> foo -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_component_initialization(lightning_squeezenet1_1_obj): with pytest.raises(TypeError): ClassificationInferenceComposable(wrongname=lightning_squeezenet1_1_obj) @@ -101,7 +101,7 @@ def test_component_initialization(lightning_squeezenet1_1_obj): assert "predicted_tag" in comp.outputs -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_component_parameters(lightning_squeezenet1_1_obj): comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) @@ -121,7 +121,7 @@ def test_component_parameters(lightning_squeezenet1_1_obj): assert first_tag.connections == comp1._flashserve_meta_.connections -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_invalid_expose_inputs(): from flash.core.serve import ModelComponent, expose from flash.core.serve.types import Number @@ -181,7 +181,7 @@ def predict(param): _ = ComposeClassEmptyExposeInputsType(lr) -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_connection_invalid_raises(lightning_squeezenet1_1_obj): comp1 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj) @@ -197,7 +197,7 @@ class FakeParam: comp1.outputs.predicted_tag >> fake_param -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_invalid_name(lightning_squeezenet1_1_obj): from flash.core.serve import ModelComponent, expose from flash.core.serve.types import Number @@ -214,7 +214,7 @@ def predict(param): return param -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_invalid_config_args(lightning_squeezenet1_1_obj): from flash.core.serve import ModelComponent, expose from flash.core.serve.types import Number @@ -241,7 +241,7 @@ def predict(self, param): _ = SomeComponent(lightning_squeezenet1_1_obj, config={"key": lambda x: x}) -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_invalid_model_args(lightning_squeezenet1_1_obj): from flash.core.serve import ModelComponent, expose from flash.core.serve.types import Number @@ -272,7 +272,7 @@ def predict(param): _ = SomeComponent({"first": lightning_squeezenet1_1_obj, "second": 233}) -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_create_invalid_endpoint(lightning_squeezenet1_1_obj): from flash.core.serve import Endpoint diff --git a/tests/core/serve/test_composition.py b/tests/core/serve/test_composition.py index c966e1dd04..1152f2c69e 100644 --- a/tests/core/serve/test_composition.py +++ b/tests/core/serve/test_composition.py @@ -4,13 +4,13 @@ import pytest from flash.core.serve import Composition, Endpoint -from flash.core.utilities.imports import _FASTAPI_AVAILABLE, _SERVE_TESTING +from flash.core.utilities.imports import _FASTAPI_AVAILABLE, _TOPIC_SERVE_AVAILABLE if _FASTAPI_AVAILABLE: from fastapi.testclient import TestClient -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_composit_endpoint_data(lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInferenceComposable @@ -56,7 +56,7 @@ def test_composit_endpoint_data(lightning_squeezenet1_1_obj): } -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_endpoint_errors_on_wrong_key_name(lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInferenceComposable @@ -136,7 +136,7 @@ def test_endpoint_errors_on_wrong_key_name(lightning_squeezenet1_1_obj): _ = Composition(comp1=comp1, predict_ep=ep) -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_composition_recieve_wrong_arg_type(lightning_squeezenet1_1_obj): # no endpoints or components with pytest.raises(TypeError): @@ -152,7 +152,7 @@ def test_composition_recieve_wrong_arg_type(lightning_squeezenet1_1_obj): _ = Composition(c1=comp1, c2=comp2) -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_servable_sequence(tmp_path, lightning_squeezenet1_1_obj, squeezenet_servable): from tests.core.serve.models import ClassificationInferenceModelSequence @@ -166,7 +166,7 @@ def test_servable_sequence(tmp_path, lightning_squeezenet1_1_obj, squeezenet_ser assert composit.components["callnum_1"].model2 == model_seq[1] -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_servable_mapping(tmp_path, lightning_squeezenet1_1_obj, squeezenet_servable): from tests.core.serve.models import ClassificationInferenceModelMapping @@ -180,7 +180,7 @@ def test_servable_mapping(tmp_path, lightning_squeezenet1_1_obj, squeezenet_serv assert composit.components["callnum_1"].model2 == model_map["model_two"] -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_invalid_servable_composition(tmp_path, lightning_squeezenet1_1_obj, squeezenet_servable): from tests.core.serve.models import ClassificationInferenceModelMapping @@ -194,7 +194,7 @@ def test_invalid_servable_composition(tmp_path, lightning_squeezenet1_1_obj, squ _ = ClassificationInferenceModelMapping(lambda x: x + 1) -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_complex_spec_single_endpoint(tmp_path, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInferenceComposable @@ -245,7 +245,7 @@ def test_complex_spec_single_endpoint(tmp_path, lightning_squeezenet1_1_obj): } -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_complex_spec_multiple_endpoints(tmp_path, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInferenceComposable @@ -321,7 +321,7 @@ def test_complex_spec_multiple_endpoints(tmp_path, lightning_squeezenet1_1_obj): } -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_start_server_from_composition(tmp_path, squeezenet_servable, session_global_datadir): from tests.core.serve.models import ClassificationInferenceComposable diff --git a/tests/core/serve/test_dag/test_optimization.py b/tests/core/serve/test_dag/test_optimization.py index 366b3cdbbb..1c9712d63c 100644 --- a/tests/core/serve/test_dag/test_optimization.py +++ b/tests/core/serve/test_dag/test_optimization.py @@ -16,14 +16,14 @@ from flash.core.serve.dag.task import get, get_dependencies from flash.core.serve.dag.utils import apply, partial_by_order from flash.core.serve.dag.utils_test import add, inc -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE def double(x): return x * 2 -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_cull(): # 'out' depends on 'x' and 'y', but not 'z' d = {"x": 1, "y": (inc, "x"), "z": (inc, "x"), "out": (add, "y", 10)} @@ -51,7 +51,7 @@ def with_deps(dsk): return dsk, {k: get_dependencies(dsk, k) for k in dsk} -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_fuse(): fuse = fuse2 # tests both `fuse` and `fuse_linear` d = { @@ -164,7 +164,7 @@ def test_fuse(): assert fuse(d, rename_keys=True) == with_deps({"a-b": (inc, 1), "c": (add, "a-b", "a-b"), "b": "a-b"}) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_fuse_keys(): fuse = fuse2 # tests both `fuse` and `fuse_linear` d = {"a": 1, "b": (inc, "a"), "c": (inc, "b")} @@ -196,7 +196,7 @@ def test_fuse_keys(): ) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_inline(): d = {"a": 1, "b": (inc, "a"), "c": (inc, "b"), "d": (add, "a", "c")} assert inline(d) == {"a": 1, "b": (inc, 1), "c": (inc, "b"), "d": (add, 1, "c")} @@ -232,7 +232,7 @@ def test_inline(): } -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_inline_functions(): x, y, i, d = "xyid" dsk = {"out": (add, i, d), i: (inc, x), d: (double, y), x: 1, y: 1} @@ -242,7 +242,7 @@ def test_inline_functions(): assert result == expected -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_inline_ignores_curries_and_partials(): dsk = {"x": 1, "y": 2, "a": (partial(add, 1), "x"), "b": (inc, "a")} @@ -251,7 +251,7 @@ def test_inline_ignores_curries_and_partials(): assert "a" not in result -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_inline_functions_non_hashable(): class NonHashableCallable: def __call__(self, a): @@ -269,14 +269,14 @@ def __hash__(self): assert "b" not in result -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_inline_doesnt_shrink_fast_functions_at_top(): dsk = {"x": (inc, "y"), "y": 1} result = inline_functions(dsk, [], fast_functions={inc}) assert result == dsk -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_inline_traverses_lists(): x, y, i, d = "xyid" dsk = {"out": (sum, [i, d]), i: (inc, x), d: (double, y), x: 1, y: 1} @@ -285,14 +285,14 @@ def test_inline_traverses_lists(): assert result == expected -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_inline_functions_protects_output_keys(): dsk = {"x": (inc, 1), "y": (double, "x")} assert inline_functions(dsk, [], [inc]) == {"y": (double, (inc, 1))} assert inline_functions(dsk, ["x"], [inc]) == {"y": (double, "x"), "x": (inc, 1)} -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_functions_of(): def a(x): return x @@ -309,7 +309,7 @@ def b(x): assert functions_of((a,)) == {a} -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_inline_cull_dependencies(): d = {"a": 1, "b": "a", "c": "b", "d": ["a", "b", "c"], "e": (add, (len, "d"), "a")} @@ -317,7 +317,7 @@ def test_inline_cull_dependencies(): inline(d2, {"b"}, dependencies=dependencies) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_fuse_reductions_single_input(): def f(*args): return args @@ -901,7 +901,7 @@ def f(*args): ) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_fuse_stressed(): def f(*args): return args @@ -974,7 +974,7 @@ def f(*args): assert rv == with_deps(rv[0]) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_fuse_reductions_multiple_input(): def f(*args): return args @@ -1082,7 +1082,7 @@ def func_with_kwargs(a, b, c=2): return a + b + c -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_SubgraphCallable(): non_hashable = [1, 2, 3] @@ -1129,7 +1129,7 @@ def test_SubgraphCallable(): assert f2(1, 2) == f(1, 2) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_SubgraphCallable_with_numpy(): np = pytest.importorskip("numpy") @@ -1149,7 +1149,7 @@ def test_SubgraphCallable_with_numpy(): assert f1 != f4 -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_fuse_subgraphs(): dsk = { "x-1": 1, @@ -1277,7 +1277,7 @@ def test_fuse_subgraphs(): assert res in sols -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_fuse_subgraphs_linear_chains_of_duplicate_deps(): dsk = { "x-1": 1, @@ -1311,7 +1311,7 @@ def test_fuse_subgraphs_linear_chains_of_duplicate_deps(): assert res == sol -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_dont_fuse_numpy_arrays(): """Some types should stay in the graph bare This helps with things like serialization.""" np = pytest.importorskip("numpy") @@ -1320,7 +1320,7 @@ def test_dont_fuse_numpy_arrays(): assert fuse(dsk, "y")[0] == dsk -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_fused_keys_max_length(): # generic fix for gh-5999 d = { "u-looooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong": ( diff --git a/tests/core/serve/test_dag/test_order.py b/tests/core/serve/test_dag/test_order.py index a404d668db..49f9bd5408 100644 --- a/tests/core/serve/test_dag/test_order.py +++ b/tests/core/serve/test_dag/test_order.py @@ -3,7 +3,7 @@ from flash.core.serve.dag.order import ndependencies, order from flash.core.serve.dag.task import get, get_deps from flash.core.serve.dag.utils_test import add, inc -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE @pytest.fixture(params=["abcde", "edcba"]) @@ -19,7 +19,7 @@ def f(*args): pass -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_ordering_keeps_groups_together(abcde): a, b, c, d, e = abcde d = {(a, i): (f,) for i in range(4)} @@ -37,7 +37,7 @@ def test_ordering_keeps_groups_together(abcde): assert abs(o[(a, 1)] - o[(a, 3)]) == 1 -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_avoid_broker_nodes(abcde): r"""Testing structure. @@ -82,7 +82,7 @@ def test_avoid_broker_nodes(abcde): assert o[(a, 0)] < o[(a, 1)] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_base_of_reduce_preferred(abcde): r"""Testing structure. @@ -113,7 +113,7 @@ def test_base_of_reduce_preferred(abcde): assert o[(b, 1)] <= 6 -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") @pytest.mark.xfail(reason="Can't please 'em all", strict=True) def test_avoid_upwards_branching(abcde): r""" @@ -146,7 +146,7 @@ def test_avoid_upwards_branching(abcde): assert o[(b, 1)] < o[(c, 1)] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_avoid_upwards_branching_complex(abcde): r""" a1 @@ -186,7 +186,7 @@ def test_avoid_upwards_branching_complex(abcde): assert abs(o[(d, 2)] - o[(d, 3)]) == 1 -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_deep_bases_win_over_dependents(abcde): r"""It's not clear who should run first, e or d. @@ -210,7 +210,7 @@ def test_deep_bases_win_over_dependents(abcde): assert o[b] < o[c] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_prefer_deep(abcde): """ c @@ -229,14 +229,14 @@ def test_prefer_deep(abcde): assert o[b] < o[d] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_stacklimit(abcde): dsk = {"x%s" % (i + 1): (inc, "x%s" % i) for i in range(10000)} dependencies, dependents = get_deps(dsk) ndependencies(dependencies, dependents) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_break_ties_by_str(abcde): a, b, c, d, e = abcde dsk = {("x", i): (inc, i) for i in range(10)} @@ -250,19 +250,19 @@ def test_break_ties_by_str(abcde): assert o == expected -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_order_doesnt_fail_on_mixed_type_keys(abcde): order({"x": (inc, 1), ("y", 0): (inc, 2), "z": (add, "x", ("y", 0))}) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_type_comparisions_ok(abcde): a, b, c, d, e = abcde dsk = {a: 1, (a, 1): 2, (a, b, 1): 3} order(dsk) # this doesn't err -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_prefer_short_dependents(abcde): r""" @@ -283,7 +283,7 @@ def test_prefer_short_dependents(abcde): assert o[e] < o[b] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") @pytest.mark.xfail(reason="This is challenging to do precisely") def test_run_smaller_sections(abcde): r"""Testing structure. @@ -326,7 +326,7 @@ def _(*args): assert log == expected -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_local_parents_of_reduction(abcde): """ @@ -375,7 +375,7 @@ def _(*args): assert log == expected -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_nearest_neighbor(abcde): r"""Testing structure. @@ -413,7 +413,7 @@ def test_nearest_neighbor(abcde): assert o[min([b1, b2, b3, b4])] == 0 -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_string_ordering(): """Prefer ordering tasks by name first.""" dsk = {("a", 1): (f,), ("a", 2): (f,), ("a", 3): (f,)} @@ -421,7 +421,7 @@ def test_string_ordering(): assert o == {("a", 1): 0, ("a", 2): 1, ("a", 3): 2} -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_string_ordering_dependents(): """Prefer ordering tasks by name first even when in dependencies.""" dsk = {("a", 1): (f, "b"), ("a", 2): (f, "b"), ("a", 3): (f, "b"), "b": (f,)} @@ -429,7 +429,7 @@ def test_string_ordering_dependents(): assert o == {"b": 0, ("a", 1): 1, ("a", 2): 2, ("a", 3): 3} -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_prefer_short_narrow(abcde): # See test_prefer_short_ancestor for a fail case. a, b, c, _, _ = abcde @@ -448,7 +448,7 @@ def test_prefer_short_narrow(abcde): assert o[(c, 1)] < o[(c, 2)] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_prefer_short_ancestor(abcde): r"""From https://github.com/dask/dask-ml/issues/206#issuecomment-395869929. @@ -508,7 +508,7 @@ def test_prefer_short_ancestor(abcde): assert o[(c, 1)] < o[(a, 1)] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_map_overlap(abcde): r"""Testing structure. @@ -548,7 +548,7 @@ def test_map_overlap(abcde): assert o[(b, 1)] < o[(e, 5)] or o[(b, 5)] < o[(e, 1)] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_use_structure_not_keys(abcde): """See https://github.com/dask/dask/issues/5584#issuecomment-554963958. @@ -589,7 +589,7 @@ def test_use_structure_not_keys(abcde): assert Bs == [1, 3, 5, 7, 9, 11, 13, 15, 17, 19] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_dont_run_all_dependents_too_early(abcde): """From https://github.com/dask/dask-ml/issues/206#issuecomment-395873372.""" a, b, c, d, e = abcde @@ -605,7 +605,7 @@ def test_dont_run_all_dependents_too_early(abcde): assert expected == actual -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_many_branches_use_ndependencies(abcde): """From https://github.com/dask/dask/pull/5646#issuecomment-562700533. @@ -642,7 +642,7 @@ def test_many_branches_use_ndependencies(abcde): assert o[(c, 1)] == o[(a, 3)] - 1 -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_order_cycle(): with pytest.raises(RuntimeError, match="Cycle detected"): get({"a": (f, "a")}, "a") # we encounter this in `get` @@ -658,12 +658,12 @@ def test_order_cycle(): order({"a": (f, "b"), "b": (f, "c"), "c": (f, "a", "d"), "d": (f, "b")}) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_order_empty(): assert order({}) == {} -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_switching_dependents(abcde): r"""Testing structure. @@ -719,7 +719,7 @@ def test_switching_dependents(abcde): assert o[(a, 5)] > o[(e, 6)] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_order_with_equal_dependents(abcde): """From https://github.com/dask/dask/issues/5859#issuecomment-608422198. diff --git a/tests/core/serve/test_dag/test_rewrite.py b/tests/core/serve/test_dag/test_rewrite.py index 2bfbae013a..1e1bfa8a35 100644 --- a/tests/core/serve/test_dag/test_rewrite.py +++ b/tests/core/serve/test_dag/test_rewrite.py @@ -1,7 +1,7 @@ import pytest from flash.core.serve.dag.rewrite import VAR, RewriteRule, RuleSet, Traverser, args, head -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE def inc(x): @@ -16,7 +16,7 @@ def double(x): return x * 2 -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_head(): assert head((inc, 1)) == inc assert head((add, 1, 2)) == add @@ -24,7 +24,7 @@ def test_head(): assert head([1, 2, 3]) == list -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_args(): assert args((inc, 1)) == (1,) assert args((add, 1, 2)) == (1, 2) @@ -32,7 +32,7 @@ def test_args(): assert args([1, 2, 3]) == [1, 2, 3] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_traverser(): term = (add, (inc, 1), (double, (inc, 1), 2)) t = Traverser(term) @@ -74,7 +74,7 @@ def repl_list(sd): rule6 = RewriteRule((list, "x"), repl_list, ("x",)) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_RewriteRule(): # Test extraneous vars are removed, varlist is correct assert rule1.vars == ("a",) @@ -89,7 +89,7 @@ def test_RewriteRule(): assert rule5._varlist == ["c", "b", "a"] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_RewriteRuleSubs(): # Test both rhs substitution and callable rhs assert rule1.subs({"a": 1}) == (inc, 1) @@ -100,7 +100,7 @@ def test_RewriteRuleSubs(): rs = RuleSet(*rules) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_RuleSet(): net = ( { @@ -120,7 +120,7 @@ def test_RuleSet(): assert rs.rules == rules -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_matches(): term = (add, 2, 1) matches = list(rs.iter_matches(term)) @@ -151,7 +151,7 @@ def test_matches(): assert len(matches) == 0 -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_rewrite(): # Rewrite inside list term = (sum, [(add, 1, 1), (add, 1, 1), (add, 1, 1)]) diff --git a/tests/core/serve/test_dag/test_task.py b/tests/core/serve/test_dag/test_task.py index 3d5fc14e75..49c69ade50 100644 --- a/tests/core/serve/test_dag/test_task.py +++ b/tests/core/serve/test_dag/test_task.py @@ -15,7 +15,7 @@ subs, ) from flash.core.serve.dag.utils_test import add, inc -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE def contains(a, b): @@ -28,7 +28,7 @@ def contains(a, b): return all(a.get(k) == v for k, v in b.items()) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_istask(): assert istask((inc, 1)) assert not istask(1) @@ -37,7 +37,7 @@ def test_istask(): assert not istask(f(sum, 2)) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_preorder_traversal(): t = (add, 1, 2) assert list(preorder_traversal(t)) == [add, 1, 2] @@ -47,7 +47,7 @@ def test_preorder_traversal(): assert list(preorder_traversal(t)) == [add, sum, list, 1, 2, 3] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_get_dependencies_nested(): dsk = {"x": 1, "y": 2, "z": (add, (inc, [["x"]]), "y")} @@ -55,34 +55,34 @@ def test_get_dependencies_nested(): assert sorted(get_dependencies(dsk, "z", as_list=True)) == ["x", "y"] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_get_dependencies_empty(): dsk = {"x": (inc,)} assert get_dependencies(dsk, "x") == set() assert get_dependencies(dsk, "x", as_list=True) == [] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_get_dependencies_list(): dsk = {"x": 1, "y": 2, "z": ["x", [(inc, "y")]]} assert get_dependencies(dsk, "z") == {"x", "y"} assert sorted(get_dependencies(dsk, "z", as_list=True)) == ["x", "y"] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_get_dependencies_task(): dsk = {"x": 1, "y": 2, "z": ["x", [(inc, "y")]]} assert get_dependencies(dsk, task=(inc, "x")) == {"x"} assert get_dependencies(dsk, task=(inc, "x"), as_list=True) == ["x"] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_get_dependencies_nothing(): with pytest.raises(ValueError): get_dependencies({}) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_get_dependencies_many(): dsk = { "a": [1, 2, 3], @@ -105,14 +105,14 @@ def test_get_dependencies_many(): assert s == [] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_get_dependencies_task_none(): # Regression test for https://github.com/dask/distributed/issues/2756 dsk = {"foo": None} assert get_dependencies(dsk, task=dsk["foo"]) == set() -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_get_deps(): """ >>> dsk = {'a': 1, 'b': (inc, 'a'), 'c': (inc, 'b')} @@ -149,13 +149,13 @@ def test_get_deps(): } -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_flatten(): assert list(flatten(())) == [] assert list(flatten("foo")) == ["foo"] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_subs(): assert subs((sum, [1, "x"]), "x", 2) == (sum, [1, 2]) assert subs((sum, [1, ["x"]]), "x", 2) == (sum, [1, [2]]) @@ -169,7 +169,7 @@ def __eq__(self, other): return False -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_subs_no_key_data_eq(): # Numpy throws a deprecation warning on bool(array == scalar), which # pollutes the terminal. This test checks that `subs` never tries to @@ -182,7 +182,7 @@ def test_subs_no_key_data_eq(): assert a.hit_eq == 0 -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_subs_with_unfriendly_eq(): try: import numpy as np @@ -203,7 +203,7 @@ def __eq__(self, other): assert subs(task, 1, 2) is task -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_subs_with_surprisingly_friendly_eq(): try: import pandas as pd @@ -214,7 +214,7 @@ def test_subs_with_surprisingly_friendly_eq(): assert subs(df, "x", 1) is df -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_subs_unexpected_hashable_key(): class UnexpectedButHashable: def __init__(self): @@ -229,7 +229,7 @@ def __eq__(self, other): assert subs((id, UnexpectedButHashable()), UnexpectedButHashable(), 1) == (id, 1) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_quote(): literals = [[1, 2, 3], (add, 1, 2), [1, [2, 3]], (add, 1, (add, 2, 3)), {"x": "x"}] @@ -237,7 +237,7 @@ def test_quote(): assert get({"x": quote(le)}, "x") == le -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_literal_serializable(): le = literal((add, 1, 2)) assert pickle.loads(pickle.dumps(le)).data == (add, 1, 2) diff --git a/tests/core/serve/test_dag/test_utils.py b/tests/core/serve/test_dag/test_utils.py index a242d1af50..f3faf47606 100644 --- a/tests/core/serve/test_dag/test_utils.py +++ b/tests/core/serve/test_dag/test_utils.py @@ -5,13 +5,13 @@ import pytest from flash.core.serve.dag.utils import funcname, partial_by_order -from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _SERVE_TESTING +from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _TOPIC_SERVE_AVAILABLE if _CYTOOLZ_AVAILABLE: from cytoolz import curry -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_funcname_long(): def a_long_function_name_11111111111111111111111111111111111111111111111(): pass @@ -21,7 +21,7 @@ def a_long_function_name_11111111111111111111111111111111111111111111111(): assert len(result) < 60 -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_funcname_cytoolz(): @curry def foo(a, b, c): @@ -37,12 +37,12 @@ def bar(a, b): assert funcname(c_bar) == "bar" -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_partial_by_order(): assert partial_by_order(5, function=operator.add, other=[(1, 20)]) == 25 -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_funcname(): assert funcname(np.floor_divide) == "floor_divide" assert funcname(partial(bool)) == "bool" @@ -50,7 +50,7 @@ def test_funcname(): assert funcname(lambda x: x) == "lambda" -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_numpy_vectorize_funcname(): def myfunc(a, b): """Return a-b if a>b, otherwise return a+b.""" diff --git a/tests/core/serve/test_gridbase_validations.py b/tests/core/serve/test_gridbase_validations.py index ab8c29fa08..c343d95622 100644 --- a/tests/core/serve/test_gridbase_validations.py +++ b/tests/core/serve/test_gridbase_validations.py @@ -2,7 +2,7 @@ from flash.core.serve import ModelComponent, expose from flash.core.serve.types import Number -from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _SERVE_TESTING +from flash.core.utilities.imports import _CYTOOLZ_AVAILABLE, _TOPIC_SERVE_AVAILABLE @pytest.mark.skipif(not _CYTOOLZ_AVAILABLE, reason="the library cytoolz is not installed.") @@ -168,7 +168,7 @@ def predict(param): return param -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_ModelComponent_raises_if_exposed_input_keys_differ_from_decorated_method_parameters( lightning_squeezenet1_1_obj, ): @@ -193,7 +193,7 @@ def predict(self, param): _ = FailedExposedDecorator(comp) -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve is not installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve is not installed.") def test_ModelComponent_raises_if_config_is_empty_dict(lightning_squeezenet1_1_obj): """This occurs when the instance is being initialized. diff --git a/tests/core/serve/test_integration.py b/tests/core/serve/test_integration.py index 8d5063fa14..d54d2dab60 100644 --- a/tests/core/serve/test_integration.py +++ b/tests/core/serve/test_integration.py @@ -3,9 +3,9 @@ import pytest from flash.core.serve import Composition, Endpoint -from flash.core.utilities.imports import _FASTAPI_AVAILABLE, _SERVE_AVAILABLE, _SERVE_TESTING +from flash.core.utilities.imports import _FASTAPI_AVAILABLE, _TOPIC_SERVE_AVAILABLE -if _SERVE_AVAILABLE: +if _TOPIC_SERVE_AVAILABLE: from jinja2 import TemplateNotFound else: TemplateNotFound = ... @@ -14,7 +14,7 @@ from fastapi.testclient import TestClient -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_resnet_18_inference_class(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInference @@ -42,7 +42,7 @@ def test_resnet_18_inference_class(session_global_datadir, lightning_squeezenet1 assert expected == resp.json() -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_start_server_with_repeated_exposed(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInferenceRepeated @@ -67,7 +67,7 @@ def test_start_server_with_repeated_exposed(session_global_datadir, lightning_sq assert resp.json() == expected -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_serving_single_component_and_endpoint_no_composition(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInference @@ -158,7 +158,7 @@ def test_serving_single_component_and_endpoint_no_composition(session_global_dat } -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") @pytest.mark.xfail(TemplateNotFound, reason="jinja2.exceptions.TemplateNotFound: dag.html") # todo def test_serving_composed(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInference, SeatClassifier @@ -213,7 +213,7 @@ def test_serving_composed(session_global_datadir, lightning_squeezenet1_1_obj): assert resp.template.name == "dag.html" -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") @pytest.mark.xfail(TemplateNotFound, reason="jinja2.exceptions.TemplateNotFound: dag.html") # todo def test_composed_does_not_eliminate_endpoint_serialization(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInference, SeatClassifier @@ -289,7 +289,7 @@ def test_composed_does_not_eliminate_endpoint_serialization(session_global_datad assert resp.template.name == "dag.html" -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") @pytest.mark.xfail(TemplateNotFound, reason="jinja2.exceptions.TemplateNotFound: dag.html") # todo def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInference, SeatClassifier @@ -402,7 +402,7 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ } -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_cycle_in_connection_fails(session_global_datadir, lightning_squeezenet1_1_obj): from tests.core.serve.models import ClassificationInferenceComposable @@ -412,7 +412,7 @@ def test_cycle_in_connection_fails(session_global_datadir, lightning_squeezenet1 c1.outputs.cropped_img >> c1.inputs.img -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") def test_composition_from_url_torchscript_servable(tmp_path): from flash.core.serve import ModelComponent, Servable, expose from flash.core.serve.types import Number diff --git a/tests/core/serve/test_types/test_bbox.py b/tests/core/serve/test_types/test_bbox.py index 48eeca3896..3fe9e273b6 100644 --- a/tests/core/serve/test_types/test_bbox.py +++ b/tests/core/serve/test_types/test_bbox.py @@ -2,10 +2,10 @@ import torch from flash.core.serve.types import BBox -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_deserialize(): bbox = BBox() assert torch.allclose(bbox.deserialize((0, 0, 0, 0)), torch.zeros((4,))) @@ -34,7 +34,7 @@ def test_deserialize(): ) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_serialize(): bbox = BBox() assert bbox.serialize(torch.ones(4)) == [1.0, 1.0, 1.0, 1.0] diff --git a/tests/core/serve/test_types/test_image.py b/tests/core/serve/test_types/test_image.py index 7470218e26..d96dc7671a 100644 --- a/tests/core/serve/test_types/test_image.py +++ b/tests/core/serve/test_types/test_image.py @@ -5,10 +5,10 @@ from torch import Tensor from flash.core.serve.types import Image -from flash.core.utilities.imports import _PIL_AVAILABLE, _SERVE_TESTING +from flash.core.utilities.imports import _PIL_AVAILABLE, _TOPIC_SERVE_AVAILABLE -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") @pytest.mark.skipif(not _PIL_AVAILABLE, reason="library PIL is not installed.") def test_deserialize_serialize(session_global_datadir): with (session_global_datadir / "cat.jpg").open("rb") as f: diff --git a/tests/core/serve/test_types/test_label.py b/tests/core/serve/test_types/test_label.py index fb3b324372..72c037e6fb 100644 --- a/tests/core/serve/test_types/test_label.py +++ b/tests/core/serve/test_types/test_label.py @@ -2,23 +2,23 @@ import torch from flash.core.serve.types import Label -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_path(session_global_datadir): label = Label(path=str(session_global_datadir / "imagenet_labels.txt")) assert label.deserialize("chickadee") == torch.tensor(19) assert label.serialize(torch.tensor(19)) == "chickadee" -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_list(): label = Label(classes=["classA", "classB"]) assert label.deserialize("classA") == torch.tensor(0) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_dict(): label = Label(classes={56: "classA", 48: "classB"}) assert label.deserialize("classA") == torch.tensor(56) @@ -27,7 +27,7 @@ def test_dict(): Label(classes={"wrongtype": "classA"}) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_wrong_type(): with pytest.raises(TypeError): Label(classes=set()) diff --git a/tests/core/serve/test_types/test_number.py b/tests/core/serve/test_types/test_number.py index 39a375ae38..2821fdc3d7 100644 --- a/tests/core/serve/test_types/test_number.py +++ b/tests/core/serve/test_types/test_number.py @@ -2,10 +2,10 @@ import torch from flash.core.serve.types import Number -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_serialize(): num = Number() tensor = torch.tensor([[1]]) @@ -23,7 +23,7 @@ def test_serialize(): num.serialize(tensor) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_deserialize(): num = Number() assert num.deserialize(1).shape == torch.Size([1, 1]) diff --git a/tests/core/serve/test_types/test_repeated.py b/tests/core/serve/test_types/test_repeated.py index 16132543a0..a7d7b035b8 100644 --- a/tests/core/serve/test_types/test_repeated.py +++ b/tests/core/serve/test_types/test_repeated.py @@ -2,17 +2,17 @@ import torch from flash.core.serve.types import Label, Repeated -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_repeated_deserialize(): repeated = Repeated(dtype=Label(classes=["classA", "classB"])) res = repeated.deserialize(*({"label": "classA"}, {"label": "classA"}, {"label": "classB"})) assert res == (torch.tensor(0), torch.tensor(0), torch.tensor(1)) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_repeated_serialize(session_global_datadir): repeated = Repeated(dtype=Label(path=str(session_global_datadir / "imagenet_labels.txt"))) assert repeated.deserialize(*({"label": "chickadee"}, {"label": "stingray"})) == ( @@ -23,7 +23,7 @@ def test_repeated_serialize(session_global_datadir): assert repeated.serialize(torch.tensor([19, 6])) == ("chickadee", "stingray") -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_repeated_max_len(): repeated = Repeated(dtype=Label(classes=["classA", "classB"]), max_len=2) @@ -47,7 +47,7 @@ def test_repeated_max_len(): Repeated(dtype=Label(classes=["classA", "classB"]), max_len=str) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_repeated_non_serve_dtype(): class NonServeDtype: pass @@ -56,7 +56,7 @@ class NonServeDtype: Repeated(NonServeDtype()) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_not_allow_nested_repeated(): with pytest.raises(TypeError): Repeated(dtype=Repeated()) diff --git a/tests/core/serve/test_types/test_table.py b/tests/core/serve/test_types/test_table.py index 7d6f939664..78eb96d1bf 100644 --- a/tests/core/serve/test_types/test_table.py +++ b/tests/core/serve/test_types/test_table.py @@ -2,7 +2,7 @@ import torch from flash.core.serve.types import Table -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE data = torch.tensor([[0.00632, 18.0, 2.31, 0.0, 0.538, 6.575, 65.2, 4.09, 1.0, 296.0, 15.3, 396.9, 4.98]]) feature_names = [ @@ -22,7 +22,7 @@ ] -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_serialize_success(): table = Table(column_names=feature_names) sample = data @@ -31,7 +31,7 @@ def test_serialize_success(): assert d2 == {0: d1.item()} -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_serialize_wrong_shape(): table = Table(column_names=feature_names) sample = data.squeeze() @@ -50,7 +50,7 @@ def test_serialize_wrong_shape(): table.serialize(sample) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_serialize_without_column_names(): with pytest.raises(TypeError): Table() @@ -60,7 +60,7 @@ def test_serialize_without_column_names(): assert list(dict_data.keys()) == feature_names -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_deserialize(): arr = torch.tensor([100, 200]).view(1, 2) table = Table(column_names=["t1", "t2"]) @@ -76,7 +76,7 @@ def test_deserialize(): ) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_deserialize_column_names_failures(): table = Table(["t1", "t2"]) with pytest.raises(RuntimeError): diff --git a/tests/core/serve/test_types/test_text.py b/tests/core/serve/test_types/test_text.py index fd034bd7d6..4aa3b7afd4 100644 --- a/tests/core/serve/test_types/test_text.py +++ b/tests/core/serve/test_types/test_text.py @@ -3,7 +3,7 @@ import pytest import torch -from flash.core.utilities.imports import _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE @dataclass @@ -17,7 +17,7 @@ def decode(self, tensor): return f"decoding from {self.name}" -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_custom_tokenizer(): from flash.core.serve.types import Text @@ -27,7 +27,7 @@ def test_custom_tokenizer(): assert "decoding from test" == text.serialize(torch.tensor([[1, 2]])) -@pytest.mark.skipif(not _SERVE_TESTING, reason="Not testing serve.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") def test_tokenizer_string(): from flash.core.serve.types import Text diff --git a/tests/core/test_classification.py b/tests/core/test_classification.py index df7b69d16b..dff8750cd2 100644 --- a/tests/core/test_classification.py +++ b/tests/core/test_classification.py @@ -22,10 +22,10 @@ ProbabilitiesOutput, ) from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _CORE_TESTING, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TOPIC_CORE_AVAILABLE, _TOPIC_IMAGE_AVAILABLE -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_classification_outputs(): example_output = torch.tensor([-0.1, 0.2, 0.3]) # 3 classes labels = ["class_1", "class_2", "class_3"] @@ -38,7 +38,7 @@ def test_classification_outputs(): assert LabelsOutput(labels).transform(example_output) == "class_3" -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_classification_outputs_multi_label(): example_output = torch.tensor([-0.1, 0.2, 0.3]) # 3 classes labels = ["class_1", "class_2", "class_3"] @@ -52,7 +52,7 @@ def test_classification_outputs_multi_label(): assert LabelsOutput(labels, multi_label=True).transform(example_output) == ["class_2", "class_3"] -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") def test_classification_outputs_fiftyone(): logits = torch.tensor([-0.1, 0.2, 0.3]) diff --git a/tests/core/test_data.py b/tests/core/test_data.py index 9c711de50c..522c7224e8 100644 --- a/tests/core/test_data.py +++ b/tests/core/test_data.py @@ -16,7 +16,7 @@ from flash import DataKeys, DataModule, RunningStage from flash.core.data.data_module import DatasetInput -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE # ======== Mock functions ======== @@ -32,7 +32,7 @@ def __len__(self) -> int: # =============================== -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_init(): train_input = DatasetInput(RunningStage.TRAINING, DummyDataset()) val_input = DatasetInput(RunningStage.VALIDATING, DummyDataset()) @@ -48,7 +48,7 @@ def test_init(): assert data_module.train_dataset and data_module.val_dataset and data_module.test_dataset -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_dataloaders(): train_input = DatasetInput(RunningStage.TRAINING, DummyDataset()) val_input = DatasetInput(RunningStage.VALIDATING, DummyDataset()) diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index 662eddc64a..72115158fb 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -26,7 +26,7 @@ import flash from flash.core.finetuning import _FINETUNING_STRATEGIES_REGISTRY from flash.core.model import Task -from flash.core.utilities.imports import _CORE_TESTING, _DEEPSPEED_AVAILABLE +from flash.core.utilities.imports import _DEEPSPEED_AVAILABLE, _TOPIC_CORE_AVAILABLE from tests.helpers.boring_model import BoringModel @@ -135,7 +135,7 @@ def on_train_epoch_start(self, trainer, pl_module): assert pl_module.model.layer.weight.requires_grad -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( "strategy, plugins", [ @@ -173,7 +173,7 @@ def test_finetuning_with_none_return_type(strategy, plugins): trainer.finetune(task, train_dataloader=DataLoader(ds), strategy=strategy) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( ("strategy", "lr_scheduler", "checker_class", "checker_class_data"), [ @@ -220,7 +220,7 @@ def test_finetuning(tmpdir, strategy, lr_scheduler, checker_class, checker_class trainer.finetune(task, train_dataloader=DataLoader(ds), strategy=strategy) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( "strategy,error", [ diff --git a/tests/core/test_model.py b/tests/core/test_model.py index c60cf1c9ad..489d100662 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -39,14 +39,14 @@ from flash.core.data.io.output_transform import OutputTransform from flash.core.utilities.embedder import Embedder from flash.core.utilities.imports import ( - _AUDIO_TESTING, - _CORE_TESTING, - _GRAPH_TESTING, - _IMAGE_AVAILABLE, - _IMAGE_TESTING, _PL_GREATER_EQUAL_1_8_0, - _TABULAR_TESTING, - _TEXT_TESTING, + _TM_GREATER_EQUAL_0_10_0, + _TOPIC_AUDIO_AVAILABLE, + _TOPIC_CORE_AVAILABLE, + _TOPIC_GRAPH_AVAILABLE, + _TOPIC_IMAGE_AVAILABLE, + _TOPIC_TABULAR_AVAILABLE, + _TOPIC_TEXT_AVAILABLE, _TORCH_OPTIMIZER_AVAILABLE, _TRANSFORMERS_AVAILABLE, ) @@ -158,7 +158,7 @@ def __init__(self, child): # ================================ -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("metrics", [None, Accuracy(), {"accuracy": Accuracy()}]) def test_classificationtask_train(tmpdir: str, metrics: Any): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) @@ -171,7 +171,7 @@ def test_classificationtask_train(tmpdir: str, metrics: Any): assert "test_nll_loss" in result[0] -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_task_predict_raises(): with pytest.raises(AttributeError, match="`flash.Task.predict` has been removed."): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) @@ -179,7 +179,7 @@ def test_task_predict_raises(): task.predict("args", kwarg="test") -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("task", [Parent, GrandParent, AdapterParent]) def test_nested_tasks(tmpdir, task): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) @@ -195,7 +195,7 @@ def test_nested_tasks(tmpdir, task): assert "test_nll_loss" in result[0] -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_classification_task_trainer_predict(tmpdir): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) task = ClassificationTask(model) @@ -214,7 +214,7 @@ def test_classification_task_trainer_predict(tmpdir): ImageClassifier, "0.7.0/image_classification_model.pt", marks=pytest.mark.skipif( - not _IMAGE_TESTING, + not _TOPIC_IMAGE_AVAILABLE, reason="image packages aren't installed", ), ), @@ -222,7 +222,7 @@ def test_classification_task_trainer_predict(tmpdir): SemanticSegmentation, "0.9.0/semantic_segmentation_model.pt", marks=pytest.mark.skipif( - not _IMAGE_TESTING, + not _TOPIC_IMAGE_AVAILABLE or not _TM_GREATER_EQUAL_0_10_0, reason="image packages aren't installed", ), ), @@ -230,7 +230,7 @@ def test_classification_task_trainer_predict(tmpdir): SpeechRecognition, "0.7.0/speech_recognition_model.pt", marks=pytest.mark.skipif( - not _AUDIO_TESTING, + not _TOPIC_AUDIO_AVAILABLE, reason="audio packages aren't installed", ), ), @@ -238,7 +238,7 @@ def test_classification_task_trainer_predict(tmpdir): TabularClassifier, "0.7.0/tabular_classification_model.pt", marks=pytest.mark.skipif( - not _TABULAR_TESTING, + not _TOPIC_TABULAR_AVAILABLE, reason="tabular packages aren't installed", ), ), @@ -246,7 +246,7 @@ def test_classification_task_trainer_predict(tmpdir): TextClassifier, "0.9.0/text_classification_model.pt", marks=pytest.mark.skipif( - not _TEXT_TESTING, + not _TOPIC_TEXT_AVAILABLE, reason="text packages aren't installed", ), ), @@ -254,7 +254,7 @@ def test_classification_task_trainer_predict(tmpdir): SummarizationTask, "0.7.0/summarization_model_xsum.pt", marks=pytest.mark.skipif( - not _TEXT_TESTING, + not _TOPIC_TEXT_AVAILABLE, reason="text packages aren't installed", ), ), @@ -262,7 +262,7 @@ def test_classification_task_trainer_predict(tmpdir): TranslationTask, "0.7.0/translation_model_en_ro.pt", marks=pytest.mark.skipif( - not _TEXT_TESTING, + not _TOPIC_TEXT_AVAILABLE, reason="text packages aren't installed", ), ), @@ -270,7 +270,7 @@ def test_classification_task_trainer_predict(tmpdir): GraphClassifier, "0.7.0/graph_classification_model.pt", marks=pytest.mark.skipif( - not _GRAPH_TESTING, + not _TOPIC_GRAPH_AVAILABLE, reason="graph packages aren't installed", ), ), @@ -278,7 +278,7 @@ def test_classification_task_trainer_predict(tmpdir): GraphEmbedder, "0.7.0/graph_classification_model.pt", marks=pytest.mark.skipif( - not _GRAPH_TESTING, + not _TOPIC_GRAPH_AVAILABLE, reason="graph packages aren't installed", ), ), @@ -305,7 +305,7 @@ def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0): return self.backbone(batch) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_as_embedder(): layer_number = 1 embedder = DummyTask().as_embedder(f"backbone.{layer_number}") @@ -314,13 +314,13 @@ def test_as_embedder(): assert embedder.predict_step(torch.rand(10, 10), 0, 0).size(1) == embedder.model.backbone[layer_number].out_features -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_available_layers(): task = DummyTask() assert task.available_layers() == ["output", "", "backbone", "backbone.0", "backbone.1", "backbone.2"] -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_available_backbones(): backbones = ImageClassifier.available_backbones() assert "resnet152" in backbones @@ -331,7 +331,7 @@ class Foo(ImageClassifier): assert Foo.available_backbones() is None -@pytest.mark.skipif(_IMAGE_AVAILABLE, reason="image libraries are installed.") +@pytest.mark.skipif(_TOPIC_IMAGE_AVAILABLE, reason="image libraries are installed.") def test_available_backbones_raises(): with pytest.raises(ModuleNotFoundError, match="Required dependencies not available."): _ = ImageClassifier.available_backbones() @@ -356,7 +356,7 @@ def custom_steplr_configuration_return_as_dict(optimizer): } -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( "optim", ["Adadelta", functools.partial(torch.optim.Adadelta, eps=0.5), ("Adadelta", {"eps": 0.5})] ) @@ -396,7 +396,7 @@ def test_optimizers_and_schedulers(tmpdir, optim, sched, interval): trainer.fit(task, train_dl) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_optimizer_learning_rate(): mock_optimizer = MagicMock() Task.optimizers_registry(mock_optimizer, "test") @@ -475,7 +475,7 @@ def train_dataloader(self): assert isinstance(trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.LambdaLR) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_manual_optimization(tmpdir): class ManualOptimizationTask(Task): def __init__(self, *args, **kwargs): @@ -505,7 +505,7 @@ def training_step(self, batch: Any, batch_idx: int) -> Any: trainer.fit(task, train_dl, val_dl) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_errors_and_exceptions_optimizers_and_schedulers(): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) @@ -542,7 +542,7 @@ def test_errors_and_exceptions_optimizers_and_schedulers(): task.configure_optimizers() -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_classification_task_metrics(): train_dataset = FixedDataset([0, 1]) val_dataset = FixedDataset([1, 1]) @@ -566,7 +566,7 @@ def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> trainer.test(task, DataLoader(test_dataset)) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_loss_fn_buffer(): weight = torch.rand(10) model = Task(loss_fn=nn.CrossEntropyLoss(weight=weight)) diff --git a/tests/core/test_registry.py b/tests/core/test_registry.py index a8f9f3f45c..031b2dbfb9 100644 --- a/tests/core/test_registry.py +++ b/tests/core/test_registry.py @@ -17,10 +17,10 @@ from torch import nn from flash.core.registry import ConcatRegistry, ExternalRegistry, FlashRegistry -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_registry_raises(): backbones = FlashRegistry("backbones") @@ -47,7 +47,7 @@ def my_model(nc_input=5, nc_output=6): backbones(name=float) # noqa -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_registry(): backbones = FlashRegistry("backbones") @@ -112,7 +112,7 @@ def my_model(): assert "bar" in backbones -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_external_registry(): def getter(key: str): return key @@ -130,7 +130,7 @@ def getter(key: str): assert len(registry.available_keys()) == 0 -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_concat_registry(): registry_1 = FlashRegistry("backbones") registry_2 = FlashRegistry("backbones") diff --git a/tests/core/test_trainer.py b/tests/core/test_trainer.py index 57aaeb386f..928f990737 100644 --- a/tests/core/test_trainer.py +++ b/tests/core/test_trainer.py @@ -26,7 +26,7 @@ from flash import Trainer from flash.core.classification import ClassificationTask -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE class DummyDataset(torch.utils.data.Dataset): @@ -67,7 +67,7 @@ def finetune_function( pass -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("callbacks, should_warn", [([], False), ([NoFreeze()], True)]) def test_trainer_fit(tmpdir, callbacks, should_warn): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) @@ -83,7 +83,7 @@ def test_trainer_fit(tmpdir, callbacks, should_warn): trainer.fit(task, train_dl, val_dl) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_trainer_finetune(tmpdir): model = DummyClassifier() train_dl = DataLoader(DummyDataset()) @@ -93,7 +93,7 @@ def test_trainer_finetune(tmpdir): trainer.finetune(task, train_dl, val_dl, strategy=NoFreeze()) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_resolve_callbacks_invalid_strategy(tmpdir): model = DummyClassifier() trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) @@ -111,7 +111,7 @@ def configure_finetune_callback( return [NoFreeze(), NoFreeze()] -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_resolve_callbacks_multi_error(tmpdir): model = DummyClassifier() trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) @@ -120,7 +120,7 @@ def test_resolve_callbacks_multi_error(tmpdir): trainer._resolve_callbacks(task, None) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_resolve_callbacks_override_warning(tmpdir): model = DummyClassifier() trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) @@ -129,7 +129,7 @@ def test_resolve_callbacks_override_warning(tmpdir): trainer._resolve_callbacks(task, strategy="no_freeze") -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_add_argparse_args(): parser = ArgumentParser() parser = Trainer.add_argparse_args(parser) @@ -137,7 +137,7 @@ def test_add_argparse_args(): assert args.gpus == 1 -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_from_argparse_args(): parser = ArgumentParser() parser = Trainer.add_argparse_args(parser) diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py index 13d9a5a43e..8bc2abc8d1 100644 --- a/tests/core/test_utils.py +++ b/tests/core/test_utils.py @@ -17,7 +17,7 @@ from flash.core.data.utils import download_data from flash.core.utilities.apply_func import get_callable_dict, get_callable_name -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE # ======== Mock functions ======== @@ -34,14 +34,14 @@ def b(): # ============================== -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_get_callable_name(): assert get_callable_name(A()) == "a" assert get_callable_name(b) == "b" assert get_callable_name(lambda: True) == "" -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_get_callable_dict(): d = get_callable_dict(A()) assert type(d["a"]) is A @@ -55,7 +55,7 @@ def test_get_callable_dict(): assert d["two"] == b -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("file", ["titanic.zip", "titanic.tar.gz", "titanic.tar.bz2"]) def test_download_data(tmpdir, file): download_path = "https://pl-flash-data.s3.amazonaws.com/" diff --git a/tests/core/utilities/test_embedder.py b/tests/core/utilities/test_embedder.py index f75d3d5d6f..89abd1c5f9 100644 --- a/tests/core/utilities/test_embedder.py +++ b/tests/core/utilities/test_embedder.py @@ -19,7 +19,7 @@ from torch import nn from flash.core.utilities.embedder import Embedder -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE class EmbedderTestModel(LightningModule): @@ -37,7 +37,7 @@ def __init__(self, n_layers): super().__init__(nn.Sequential(*[nn.Linear(1000, 1000) for _ in range(n_layers)])) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("layer, size", [("backbone.1", 30), ("output", 40), ("", 40)]) def test_embedder(layer, size): """Tests that the embedder ``predict_step`` correctly returns the output from the requested layer.""" @@ -55,7 +55,7 @@ def test_embedder(layer, size): assert embedder(torch.rand(10, 10)).size(1) == size -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_embedder_scaling_overhead(): """Tests that embedding to the 3rd layer of a 200 layer model takes less than double the time of embedding to. @@ -80,10 +80,11 @@ def test_embedder_scaling_overhead(): deep_time = end - start - assert (abs(deep_time - shallow_time) / shallow_time) < 1 + diff_time = abs(deep_time - shallow_time) + assert (diff_time / shallow_time) < 2 -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_embedder_raising_overhead(): """Tests that embedding to the output layer of a 3 layer model takes less than 10ms more than the time taken to execute the model without the embedder. @@ -105,4 +106,4 @@ def test_embedder_raising_overhead(): embedder_time = end - start - assert abs(embedder_time - model_time) < 0.01 + assert abs(embedder_time - model_time) < 0.05 diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py index b1eee81e7e..f4a4c2493a 100644 --- a/tests/core/utilities/test_lightning_cli.py +++ b/tests/core/utilities/test_lightning_cli.py @@ -22,9 +22,9 @@ from flash.core.utilities.compatibility import accelerator_connector from flash.core.utilities.imports import ( - _CORE_TESTING, _PL_GREATER_EQUAL_1_4_0, _PL_GREATER_EQUAL_1_6_0, + _TOPIC_CORE_AVAILABLE, _TORCHVISION_AVAILABLE, ) from flash.core.utilities.lightning_cli import ( @@ -40,7 +40,7 @@ torchvision_version = version.parse(__import__("torchvision").__version__) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @mock.patch("argparse.ArgumentParser.parse_args") def test_default_args(mock_argparse, tmpdir): """Tests default argument parser for Trainer.""" @@ -56,7 +56,7 @@ def test_default_args(mock_argparse, tmpdir): assert trainer.max_epochs == 5 -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--weights_save_path=./"], []]) def test_add_argparse_args_redefined(cli_args): """Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness.""" @@ -78,7 +78,7 @@ def test_add_argparse_args_redefined(cli_args): assert isinstance(trainer, Trainer) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( ["cli_args", "expected"], [ @@ -108,7 +108,7 @@ def test_parse_args_parsing(cli_args, expected): assert Trainer.from_argparse_args(args) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( ["cli_args", "expected", "instantiate"], [ @@ -130,7 +130,7 @@ def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): assert Trainer.from_argparse_args(args) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( ["cli_args", "expected_gpu"], [ @@ -152,7 +152,7 @@ def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): assert trainer.data_parallel_device_ids == expected_gpu -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.skipif( sys.version_info < (3, 7), reason="signature inspection while mocking is not working in Python < 3.7 despite autospec", @@ -197,7 +197,7 @@ def trainer_builder( return Trainer(limit_train_batches=limit_train_batches, fast_dev_run=fast_dev_run, callbacks=callbacks) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize(["trainer_class", "model_class"], [(Trainer, Model), (trainer_builder, model_builder)]) def test_lightning_cli(trainer_class, model_class, monkeypatch): """Test that LightningCLI correctly instantiates model, trainer and calls fit.""" @@ -230,7 +230,7 @@ def on_train_start(callback, trainer, _): assert hasattr(cli.trainer, "ran_asserts") and cli.trainer.ran_asserts -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_args_callbacks(tmpdir): callbacks = [ dict( @@ -257,7 +257,7 @@ def on_fit_start(self): assert cli.trainer.ran_asserts -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_configurable_callbacks(tmpdir): class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): @@ -277,7 +277,7 @@ def add_arguments_to_parser(self, parser): assert callback[0].logging_interval == "epoch" -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.skipif(_PL_GREATER_EQUAL_1_6_0, reason="Bugs in PL >= 1.6.0") def test_lightning_cli_args_cluster_environments(tmpdir): plugins = [dict(class_path="pytorch_lightning.plugins.environments.SLURMEnvironment")] @@ -294,7 +294,7 @@ def on_fit_start(self): assert cli.trainer.ran_asserts -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_args(tmpdir): cli_args = [ f"--data.data_dir={tmpdir}", @@ -317,7 +317,7 @@ def test_lightning_cli_args(tmpdir): assert config["trainer"] == cli.config["trainer"] -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_save_config_cases(tmpdir): config_path = tmpdir / "config.yaml" cli_args = [ @@ -342,7 +342,7 @@ def test_lightning_cli_save_config_cases(tmpdir): LightningCLI(BoringModel) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_config_and_subclass_mode(tmpdir): config = dict( model=dict(class_path="tests.helpers.boring_model.BoringModel"), @@ -380,7 +380,7 @@ def any_model_any_data_cli(): ) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_help(): cli_args = ["any.py", "--help"] out = StringIO() @@ -406,7 +406,7 @@ def test_lightning_cli_help(): assert "--data.init_args.data_dir" in out.getvalue() -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_print_config(): cli_args = [ "any.py", @@ -426,7 +426,7 @@ def test_lightning_cli_print_config(): assert outval["data"]["class_path"] == "tests.helpers.boring_model.BoringDataModule" -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_submodules(tmpdir): class MainModule(BoringModel): def __init__( @@ -464,7 +464,7 @@ def __init__( assert isinstance(cli.model.submodule2, BoringModel) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.skipif(torchvision_version < version.parse("0.8.0"), reason="torchvision>=0.8.0 is required") def test_lightning_cli_torch_modules(tmpdir): class TestModule(BoringModel): @@ -530,7 +530,7 @@ def __init__( self.num_classes = 5 # only available after instantiation -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_link_arguments(tmpdir): class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): @@ -579,7 +579,7 @@ def on_exception(self, execption): raise execption -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("logger", (False, True)) @pytest.mark.parametrize( "trainer_kwargs", @@ -612,7 +612,7 @@ def test_cli_ddp_spawn_save_config_callback(tmpdir, logger, trainer_kwargs): assert os.path.isfile(config_path) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_cli_config_overwrite(tmpdir): trainer_defaults = {"default_root_dir": str(tmpdir), "logger": False, "max_steps": 1, "max_epochs": 1} @@ -624,7 +624,7 @@ def test_cli_config_overwrite(tmpdir): LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=trainer_defaults) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_optimizer(tmpdir): class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): @@ -648,7 +648,7 @@ def add_arguments_to_parser(self, parser): assert len(cli.trainer.lr_schedulers) == 0 -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_optimizer_and_lr_scheduler(tmpdir): class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): @@ -672,7 +672,7 @@ def add_arguments_to_parser(self, parser): assert cli.trainer.lr_schedulers[0]["scheduler"].gamma == 0.8 -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(tmpdir): class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): @@ -704,7 +704,7 @@ def add_arguments_to_parser(self, parser): assert cli.trainer.lr_schedulers[0]["scheduler"].step_size == 50 -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir): class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): diff --git a/tests/core/utilities/test_stability.py b/tests/core/utilities/test_stability.py index 343caefbda..cec2fb247c 100644 --- a/tests/core/utilities/test_stability.py +++ b/tests/core/utilities/test_stability.py @@ -13,7 +13,7 @@ # limitations under the License. import pytest -from flash.core.utilities.imports import _CORE_TESTING +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE from flash.core.utilities.stability import _raise_beta_warning, beta @@ -37,7 +37,7 @@ def _beta_func_custom_message(): pass -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( "callable, match", [ diff --git a/tests/examples/test_integrations.py b/tests/examples/test_integrations.py index 035e1ec7ea..8b4f521619 100644 --- a/tests/examples/test_integrations.py +++ b/tests/examples/test_integrations.py @@ -17,7 +17,12 @@ import pytest -from flash.core.utilities.imports import _BAAL_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _LEARN2LEARN_AVAILABLE +from flash.core.utilities.imports import ( + _BAAL_AVAILABLE, + _FIFTYONE_AVAILABLE, + _LEARN2LEARN_AVAILABLE, + _TOPIC_IMAGE_AVAILABLE, +) from tests.examples.utils import run_test root = Path(__file__).parent.parent.parent @@ -31,20 +36,22 @@ "fiftyone", "image_classification.py", marks=pytest.mark.skipif( - not (_IMAGE_AVAILABLE and _FIFTYONE_AVAILABLE), reason="fiftyone library isn't installed" + not (_TOPIC_IMAGE_AVAILABLE and _FIFTYONE_AVAILABLE), reason="fiftyone library isn't installed" ), ), pytest.param( "fiftyone", "object_detection.py", marks=pytest.mark.skipif( - not (_IMAGE_AVAILABLE and _FIFTYONE_AVAILABLE), reason="fiftyone library isn't installed" + not (_TOPIC_IMAGE_AVAILABLE and _FIFTYONE_AVAILABLE), reason="fiftyone library isn't installed" ), ), pytest.param( "baal", "image_classification_active_learning.py", - marks=pytest.mark.skipif(not (_IMAGE_AVAILABLE and _BAAL_AVAILABLE), reason="baal library isn't installed"), + marks=pytest.mark.skipif( + not (_TOPIC_IMAGE_AVAILABLE and _BAAL_AVAILABLE), reason="baal library isn't installed" + ), ), pytest.param( "learn2learn", @@ -52,7 +59,7 @@ marks=[ pytest.mark.skip("MiniImagenet broken: https://github.com/learnables/learn2learn/issues/291"), pytest.mark.skipif( - not (_IMAGE_AVAILABLE and _LEARN2LEARN_AVAILABLE), reason="learn2learn isn't installed" + not (_TOPIC_IMAGE_AVAILABLE and _LEARN2LEARN_AVAILABLE), reason="learn2learn isn't installed" ), ], ), diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index e423c34b2c..ae87587603 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -20,17 +20,18 @@ import torch from flash.core.utilities.imports import ( - _AUDIO_TESTING, - _CORE_TESTING, - _GRAPH_TESTING, + _ICEDATA_AVAILABLE, _ICEVISION_AVAILABLE, - _IMAGE_AVAILABLE, - _IMAGE_EXTRAS_TESTING, - _IMAGE_TESTING, - _POINTCLOUD_TESTING, - _TABULAR_TESTING, - _TEXT_TESTING, - _VIDEO_TESTING, + _SEGMENTATION_MODELS_AVAILABLE, + _TOPIC_AUDIO_AVAILABLE, + _TOPIC_CORE_AVAILABLE, + _TOPIC_GRAPH_AVAILABLE, + _TOPIC_IMAGE_AVAILABLE, + _TOPIC_POINTCLOUD_AVAILABLE, + _TOPIC_TABULAR_AVAILABLE, + _TOPIC_TEXT_AVAILABLE, + _TOPIC_VIDEO_AVAILABLE, + _TORCHVISION_GREATER_EQUAL_0_9, _VISSL_AVAILABLE, ) from tests.examples.utils import run_test @@ -45,126 +46,133 @@ [ pytest.param( "audio_classification.py", - marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed"), ), pytest.param( "speech_recognition.py", - marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed"), ), pytest.param( "image_classification.py", - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), ), pytest.param( "image_classification_multi_label.py", - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), ), pytest.param( "image_embedder.py", marks=[ - pytest.mark.skipif( - not (_IMAGE_AVAILABLE and _VISSL_AVAILABLE), reason="image libraries aren't installed" - ), + pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), + pytest.mark.skipif(not _VISSL_AVAILABLE, reason="VISSL package isn't installed"), pytest.mark.skipif(torch.cuda.device_count() > 1, reason="VISSL integration doesn't support multi-GPU"), ], ), pytest.param( "object_detection.py", marks=pytest.mark.skipif( - not (_IMAGE_EXTRAS_TESTING and _ICEVISION_AVAILABLE), reason="image libraries aren't installed" + not (_TOPIC_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE), reason="image libraries aren't installed" ), ), pytest.param( "instance_segmentation.py", - marks=pytest.mark.skipif( - not (_IMAGE_EXTRAS_TESTING and _ICEVISION_AVAILABLE), reason="image libraries aren't installed" - ), + marks=[ + pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), + pytest.mark.skipif(not _ICEDATA_AVAILABLE, reason="icedata package isn't installed"), + ], ), pytest.param( "keypoint_detection.py", - marks=pytest.mark.skipif( - not (_IMAGE_EXTRAS_TESTING and _ICEVISION_AVAILABLE), reason="image libraries aren't installed" - ), + marks=[ + pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), + pytest.mark.skipif(not _ICEDATA_AVAILABLE, reason="icedata package isn't installed"), + ], ), pytest.param( "question_answering.py", - marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed"), ), pytest.param( "semantic_segmentation.py", - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"), + marks=[ + pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), + pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="Segmentation package isn't installed"), + pytest.mark.skipif(not _TORCHVISION_GREATER_EQUAL_0_9, reason="Newer version of TV is needed."), + ], ), pytest.param( "style_transfer.py", marks=[ - pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"), + pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), pytest.mark.skipif(torch.cuda.device_count() >= 2, reason="PyStiche doesn't support DDP"), ], ), pytest.param( - "summarization.py", marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed") + "summarization.py", + marks=pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed"), ), pytest.param( "tabular_classification.py", - marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed"), ), pytest.param( "tabular_regression.py", - marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed"), ), pytest.param( "tabular_forecasting.py", - marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed"), ), pytest.param( "template.py", marks=[ - pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core."), + pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core."), pytest.mark.skipif(os.name == "posix", reason="Flaky on Mac OS (CI)"), pytest.mark.skipif(sys.version_info >= (3, 9), reason="Undiagnosed segmentation fault in 3.9"), ], ), pytest.param( "text_classification.py", - marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed"), ), pytest.param( "text_embedder.py", - marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed"), ), # pytest.param( # "text_classification_multi_label.py", - # marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed") + # marks=pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed") # ), pytest.param( "translation.py", marks=[ - pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"), + pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed"), pytest.mark.skipif(os.name == "nt", reason="Encoding issues on Windows"), ], ), pytest.param( "video_classification.py", - marks=pytest.mark.skipif(not _VIDEO_TESTING, reason="video libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="video libraries aren't installed"), ), pytest.param( "pointcloud_segmentation.py", - marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_POINTCLOUD_AVAILABLE, reason="pointcloud libraries aren't installed"), ), pytest.param( "pointcloud_detection.py", - marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_POINTCLOUD_AVAILABLE, reason="pointcloud libraries aren't installed"), ), pytest.param( "graph_classification.py", - marks=pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_GRAPH_AVAILABLE, reason="graph libraries aren't installed"), ), pytest.param( "graph_embedder.py", - marks=pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed"), + marks=pytest.mark.skipif(not _TOPIC_GRAPH_AVAILABLE, reason="graph libraries aren't installed"), ), ], ) @forked +@pytest.mark.skipif(sys.platform == "darwin", reason="Fatal Python error: Illegal instruction") # fixme def test_example(tmpdir, file): run_test(str(root / "examples" / file)) diff --git a/tests/examples/utils.py b/tests/examples/utils.py index e341114711..13784f6471 100644 --- a/tests/examples/utils.py +++ b/tests/examples/utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import subprocess import sys from typing import List, Optional, Tuple @@ -41,8 +42,9 @@ def call_script( except subprocess.TimeoutExpired: p.kill() stdout, stderr = p.communicate() - stdout = stdout.decode("utf-8") - stderr = stderr.decode("utf-8") + encoding = "windows-1252" if os.name == "nt" else "utf-8" + stdout = stdout.decode(encoding) + stderr = stderr.decode(encoding) with open(filepath, "w") as modified: modified.writelines(data) diff --git a/tests/graph/classification/test_data.py b/tests/graph/classification/test_data.py index 76ad9eea86..0d35fb232a 100644 --- a/tests/graph/classification/test_data.py +++ b/tests/graph/classification/test_data.py @@ -14,11 +14,11 @@ import pytest from flash import DataKeys -from flash.core.utilities.imports import _GRAPH_AVAILABLE, _GRAPH_TESTING, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _TOPIC_GRAPH_AVAILABLE, _TORCHVISION_AVAILABLE from flash.graph.classification.data import GraphClassificationData from flash.graph.classification.input_transform import GraphClassificationInputTransform, PyGTransformAdapter -if _GRAPH_AVAILABLE: +if _TOPIC_GRAPH_AVAILABLE: from torch_geometric.datasets import TUDataset from torch_geometric.transforms import OneHotDegree @@ -26,7 +26,7 @@ from torchvision import transforms as T -@pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_GRAPH_AVAILABLE, reason="graph libraries aren't installed.") class TestGraphClassificationData: """Tests ``GraphClassificationData``.""" diff --git a/tests/graph/classification/test_model.py b/tests/graph/classification/test_model.py index 4e358388b7..9168ed1602 100644 --- a/tests/graph/classification/test_model.py +++ b/tests/graph/classification/test_model.py @@ -20,13 +20,13 @@ from flash import RunningStage, Trainer from flash.core.data.data_module import DataModule from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _GRAPH_AVAILABLE, _GRAPH_TESTING +from flash.core.utilities.imports import _TOPIC_GRAPH_AVAILABLE from flash.graph.classification import GraphClassifier from flash.graph.classification.input import GraphClassificationDatasetInput from flash.graph.classification.input_transform import GraphClassificationInputTransform from tests.helpers.task_tester import TaskTester -if _GRAPH_AVAILABLE: +if _TOPIC_GRAPH_AVAILABLE: from torch_geometric import datasets from torch_geometric.data import Batch, Data @@ -35,8 +35,8 @@ class TestGraphClassifier(TaskTester): task = GraphClassifier task_kwargs = {"num_features": 1, "num_classes": 2} cli_command = "graph_classification" - is_testing = _GRAPH_TESTING - is_available = _GRAPH_AVAILABLE + is_testing = _TOPIC_GRAPH_AVAILABLE + is_available = _TOPIC_GRAPH_AVAILABLE # TODO: Resolve JIT issues scriptable = False @@ -67,7 +67,7 @@ def example_test_sample(self): return self.example_train_sample -@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") +@pytest.mark.skipif(not _TOPIC_GRAPH_AVAILABLE, reason="pytorch geometric isn't installed") def test_predict_dataset(tmpdir): """Tests that we can generate predictions from a pytorch geometric dataset.""" tudataset = datasets.TUDataset(root=tmpdir, name="KKI") diff --git a/tests/graph/embedding/test_model.py b/tests/graph/embedding/test_model.py index e843c561c4..78bf198a0f 100644 --- a/tests/graph/embedding/test_model.py +++ b/tests/graph/embedding/test_model.py @@ -19,14 +19,14 @@ from flash import RunningStage, Trainer from flash.core.data.data_module import DataModule -from flash.core.utilities.imports import _GRAPH_AVAILABLE, _GRAPH_TESTING +from flash.core.utilities.imports import _TOPIC_GRAPH_AVAILABLE from flash.graph.classification.input import GraphClassificationDatasetInput from flash.graph.classification.input_transform import GraphClassificationInputTransform from flash.graph.classification.model import GraphClassifier from flash.graph.embedding.model import GraphEmbedder from tests.helpers.task_tester import TaskTester -if _GRAPH_AVAILABLE: +if _TOPIC_GRAPH_AVAILABLE: from torch_geometric import datasets from torch_geometric.data import Batch, Data from torch_geometric.nn.models import GCN @@ -37,8 +37,8 @@ class TestGraphEmbedder(TaskTester): task = GraphEmbedder task_args = (GCN(in_channels=1, hidden_channels=512, num_layers=4),) - is_testing = _GRAPH_TESTING - is_available = _GRAPH_AVAILABLE + is_testing = _TOPIC_GRAPH_AVAILABLE + is_available = _TOPIC_GRAPH_AVAILABLE # TODO: Resolve JIT issues scriptable = False @@ -55,14 +55,14 @@ def check_forward_output(self, output: Any): assert output.shape == torch.Size([1, 512]) -@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") +@pytest.mark.skipif(not _TOPIC_GRAPH_AVAILABLE, reason="pytorch geometric isn't installed") def test_smoke(): """A simple test that the class can be instantiated from a GraphClassifier backbone.""" model = GraphEmbedder(GraphClassifier(num_features=1, num_classes=1).backbone) assert model is not None -@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") +@pytest.mark.skipif(not _TOPIC_GRAPH_AVAILABLE, reason="pytorch geometric isn't installed") def test_not_trainable(tmpdir): """Tests that the model gives an error when training, validating, or testing.""" tudataset = datasets.TUDataset(root=tmpdir, name="KKI") @@ -85,7 +85,7 @@ def test_not_trainable(tmpdir): trainer.test(model, datamodule=datamodule) -@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") +@pytest.mark.skipif(not _TOPIC_GRAPH_AVAILABLE, reason="pytorch geometric isn't installed") def test_predict_dataset(tmpdir): """Tests that we can generate embeddings from a pytorch geometric dataset.""" tudataset = datasets.TUDataset(root=tmpdir, name="KKI") diff --git a/tests/image/classification/test_active_learning.py b/tests/image/classification/test_active_learning.py index 8965228838..74cd0603be 100644 --- a/tests/image/classification/test_active_learning.py +++ b/tests/image/classification/test_active_learning.py @@ -22,7 +22,7 @@ from torch.utils.data import SequentialSampler import flash -from flash.core.utilities.imports import _BAAL_AVAILABLE, _BAAL_GREATER_EQUAL_1_5_2, _IMAGE_AVAILABLE +from flash.core.utilities.imports import _BAAL_AVAILABLE, _BAAL_GREATER_EQUAL_1_5_2, _TOPIC_IMAGE_AVAILABLE from flash.image import ImageClassificationData, ImageClassifier from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop from tests.image.classification.test_data import _rand_image @@ -61,7 +61,9 @@ def simple_datamodule(tmpdir): return dm -@pytest.mark.skipif(not (_IMAGE_AVAILABLE and _BAAL_AVAILABLE), reason="image and baal libraries aren't installed.") +@pytest.mark.skipif( + not (_TOPIC_IMAGE_AVAILABLE and _BAAL_AVAILABLE), reason="image and baal libraries aren't installed." +) @pytest.mark.parametrize("initial_num_labels, query_size", [(0, 5), (5, 5)]) def test_active_learning_training(simple_datamodule, initial_num_labels, query_size): seed_everything(42) @@ -125,7 +127,9 @@ def test_active_learning_training(simple_datamodule, initial_num_labels, query_s assert len(active_learning_dm.val_dataloader()) == 5 -@pytest.mark.skipif(not (_IMAGE_AVAILABLE and _BAAL_AVAILABLE), reason="image and baal libraries aren't installed.") +@pytest.mark.skipif( + not (_TOPIC_IMAGE_AVAILABLE and _BAAL_AVAILABLE), reason="image and baal libraries aren't installed." +) def test_no_validation_loop(simple_datamodule): active_learning_dm = ActiveLearningDataModule( simple_datamodule, diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index e92bdac9fd..bdaf2aa987 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -25,10 +25,9 @@ from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, - _IMAGE_AVAILABLE, - _IMAGE_TESTING, _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, + _TOPIC_IMAGE_AVAILABLE, _TORCHVISION_AVAILABLE, ) from flash.image import ImageClassificationData, ImageClassificationInputTransform @@ -50,7 +49,7 @@ def _rand_image(size: Tuple[int, int] = None): return Image.fromarray(np.random.randint(0, 255, (*size, 3), dtype="uint8")) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_filepaths_smoke(tmpdir): tmpdir = Path(tmpdir) @@ -79,7 +78,7 @@ def test_from_filepaths_smoke(tmpdir): assert sorted(list(labels.numpy())) == [1, 2] -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_data_frame_smoke(tmpdir): tmpdir = Path(tmpdir) @@ -131,7 +130,7 @@ def test_from_data_frame_smoke(tmpdir): assert imgs.shape == (1, 3, 196, 196) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_filepaths_list_image_paths(tmpdir): tmpdir = Path(tmpdir) @@ -178,7 +177,7 @@ def test_from_filepaths_list_image_paths(tmpdir): assert list(labels.numpy()) == [2, 5] -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") def test_from_filepaths_visualise(tmpdir): tmpdir = Path(tmpdir) @@ -214,7 +213,7 @@ def test_from_filepaths_visualise(tmpdir): dm.show_train_batch(["per_sample_transform", "per_batch_transform"]) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") def test_from_filepaths_visualise_subplots_exceding_max_cols(tmpdir): tmpdir = Path(tmpdir) @@ -248,7 +247,7 @@ def test_from_filepaths_visualise_subplots_exceding_max_cols(tmpdir): dm.show_train_batch(["per_sample_transform", "per_batch_transform"]) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") def test_from_filepaths_visualise_subplots_single_image(tmpdir): tmpdir = Path(tmpdir) @@ -282,7 +281,7 @@ def test_from_filepaths_visualise_subplots_single_image(tmpdir): dm.show_train_batch(["per_sample_transform", "per_batch_transform"]) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") def test_from_filepaths_visualise_multilabel(tmpdir): tmpdir = Path(tmpdir) @@ -317,7 +316,7 @@ def test_from_filepaths_visualise_multilabel(tmpdir): dm.show_val_batch("per_batch_transform") -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_folders_only_train(tmpdir): train_dir = Path(tmpdir / "train") train_dir.mkdir() @@ -338,7 +337,7 @@ def test_from_folders_only_train(tmpdir): assert labels == 0 -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_folders_train_val(tmpdir): train_dir = Path(tmpdir / "train") train_dir.mkdir() @@ -376,7 +375,7 @@ def test_from_folders_train_val(tmpdir): assert list(labels.numpy()) == [0, 0] -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_filepaths_multilabel(tmpdir): tmpdir = Path(tmpdir) @@ -418,7 +417,7 @@ def test_from_filepaths_multilabel(tmpdir): torch.testing.assert_allclose(labels, torch.tensor(test_labels)) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.parametrize( "data,from_function", [ @@ -461,7 +460,7 @@ def test_from_data(data, from_function): assert list(labels.numpy()) == [2, 5] -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.") def test_from_fiftyone(tmpdir): tmpdir = Path(tmpdir) @@ -518,7 +517,7 @@ def test_from_fiftyone(tmpdir): assert sorted(list(labels.numpy())) == [0, 1] -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_datasets(): img_data = ImageClassificationData.from_datasets( train_dataset=FakeData(size=3, num_classes=2), @@ -566,7 +565,7 @@ def single_target_csv(image_tmpdir): return str(image_tmpdir / "metadata.csv") -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_csv_single_target(single_target_csv): img_data = ImageClassificationData.from_csv( "image", @@ -594,7 +593,7 @@ def multi_target_csv(image_tmpdir): return str(image_tmpdir / "metadata.csv") -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_csv_multi_target(multi_target_csv): img_data = ImageClassificationData.from_csv( "image", @@ -621,7 +620,7 @@ def bad_csv_no_image(image_tmpdir): return str(image_tmpdir / "metadata.csv") -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_bad_csv_no_image(bad_csv_no_image): bad_file = os.path.join(os.path.dirname(bad_csv_no_image), "image_3") with pytest.raises(ValueError, match=f"File ID `image_3` resolved to `{bad_file}`, which does not exist."): @@ -635,7 +634,7 @@ def test_from_bad_csv_no_image(bad_csv_no_image): _ = next(iter(img_data.train_dataloader())) -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_mixup(single_target_csv): @dataclass class MyTransform(ImageClassificationInputTransform): diff --git a/tests/image/classification/test_data_model_integration.py b/tests/image/classification/test_data_model_integration.py index b9ea6c85e9..8d8eb9aa5f 100644 --- a/tests/image/classification/test_data_model_integration.py +++ b/tests/image/classification/test_data_model_integration.py @@ -17,7 +17,7 @@ import pytest from flash import Trainer -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _IMAGE_TESTING, _PIL_AVAILABLE +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PIL_AVAILABLE, _TOPIC_IMAGE_AVAILABLE from flash.image import ImageClassificationData, ImageClassifier if _PIL_AVAILABLE: @@ -31,7 +31,7 @@ def _rand_image(): return Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_classification(tmpdir): tmpdir = Path(tmpdir) @@ -56,7 +56,7 @@ def test_classification(tmpdir): trainer.finetune(model, datamodule=data, strategy="freeze") -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.") def test_classification_fiftyone(tmpdir): tmpdir = Path(tmpdir) diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py index b3a146c925..c72c3ae109 100644 --- a/tests/image/classification/test_model.py +++ b/tests/image/classification/test_model.py @@ -20,7 +20,7 @@ from flash import Trainer from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _IMAGE_AVAILABLE, _IMAGE_TESTING, _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE, _TOPIC_SERVE_AVAILABLE from flash.image import ImageClassifier from tests.helpers.task_tester import TaskTester @@ -48,8 +48,8 @@ class TestImageClassifier(TaskTester): task = ImageClassifier task_args = (2,) cli_command = "image_classification" - is_testing = _IMAGE_TESTING - is_available = _IMAGE_AVAILABLE + is_testing = _TOPIC_IMAGE_AVAILABLE + is_available = _TOPIC_IMAGE_AVAILABLE marks = { "test_fit": [ @@ -113,13 +113,13 @@ def example_test_sample(self): return self.example_train_sample -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_non_existent_backbone(): with pytest.raises(KeyError): ImageClassifier(2, backbone="i am never going to implement this lol") -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_freeze(): model = ImageClassifier(2) model.freeze() @@ -127,7 +127,7 @@ def test_freeze(): assert p.requires_grad is False -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_unfreeze(): model = ImageClassifier(2) model.unfreeze() @@ -135,7 +135,7 @@ def test_unfreeze(): assert p.requires_grad is True -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_multilabel(tmpdir): num_classes = 4 ds = DummyMultiLabelDataset(num_classes) @@ -149,7 +149,7 @@ def test_multilabel(tmpdir): assert len(predictions[0]) == num_classes -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") @mock.patch("flash._IS_TESTING", True) def test_serve(): model = ImageClassifier(2) diff --git a/tests/image/classification/test_training_strategies.py b/tests/image/classification/test_training_strategies.py index 82832b7e40..99231a016d 100644 --- a/tests/image/classification/test_training_strategies.py +++ b/tests/image/classification/test_training_strategies.py @@ -20,7 +20,7 @@ from flash import Trainer from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _IMAGE_TESTING, _LEARN2LEARN_AVAILABLE, _PL_GREATER_EQUAL_1_6_0 +from flash.core.utilities.imports import _LEARN2LEARN_AVAILABLE, _PL_GREATER_EQUAL_1_6_0, _TOPIC_IMAGE_AVAILABLE from flash.image import ImageClassificationData, ImageClassifier from flash.image.classification.adapters import TRAINING_STRATEGIES from tests.image.classification.test_data import _rand_image @@ -39,7 +39,7 @@ def __len__(self) -> int: return 2 -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_default_strategies(tmpdir): num_classes = 10 ds = DummyDataset() diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py index 8c3fc3e9da..d5c04a2972 100644 --- a/tests/image/detection/test_data.py +++ b/tests/image/detection/test_data.py @@ -18,7 +18,7 @@ import pytest from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_EXTRAS_TESTING, _PIL_AVAILABLE +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, _PIL_AVAILABLE from flash.image.detection.data import ObjectDetectionData if _PIL_AVAILABLE: @@ -160,7 +160,7 @@ def _create_synth_fiftyone_dataset(tmpdir): return dataset -@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="icevision is not installed for testing") def test_image_detector_data_from_coco(tmpdir): train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) @@ -193,7 +193,6 @@ def test_image_detector_data_from_coco(tmpdir): assert sample[DataKeys.INPUT].shape == (128, 128, 3) -@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") def test_image_detector_data_from_fiftyone(tmpdir): train_dataset = _create_synth_fiftyone_dataset(tmpdir) @@ -224,7 +223,7 @@ def test_image_detector_data_from_fiftyone(tmpdir): assert sample[DataKeys.INPUT].shape == (128, 128, 3) -@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="icevision is not installed for testing") def test_image_detector_data_from_files(tmpdir): predict_files = _create_synth_files_dataset(tmpdir) datamodule = ObjectDetectionData.from_files( @@ -235,7 +234,7 @@ def test_image_detector_data_from_files(tmpdir): assert sample[DataKeys.INPUT].shape == (128, 128, 3) -@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="icevision is not installed for testing") def test_image_detector_data_from_folders(tmpdir): predict_folder = _create_synth_folders_dataset(tmpdir) datamodule = ObjectDetectionData.from_folders( diff --git a/tests/image/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py index 824782275a..fd45a57073 100644 --- a/tests/image/detection/test_data_model_integration.py +++ b/tests/image/detection/test_data_model_integration.py @@ -17,7 +17,7 @@ import torch import flash -from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_EXTRAS_TESTING, _PIL_AVAILABLE +from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _PIL_AVAILABLE from flash.image import ObjectDetector from flash.image.detection import ObjectDetectionData @@ -33,7 +33,7 @@ from tests.image.detection.test_data import _create_synth_fiftyone_dataset -@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _COCO_AVAILABLE, reason="coco is not installed for testing") @pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn"), ("faster_rcnn", "resnet18_fpn")]) def test_detection(tmpdir, head, backbone): train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) @@ -55,7 +55,6 @@ def test_detection(tmpdir, head, backbone): trainer.predict(model, datamodule=datamodule) -@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") @pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn")]) def test_detection_fiftyone(tmpdir, head, backbone): diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index 13bfb67c2d..bc2fde9a12 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -23,7 +23,12 @@ from flash.core.data.io.input import DataKeys from flash.core.integrations.icevision.transforms import IceVisionInputTransform from flash.core.trainer import Trainer -from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE, _IMAGE_EXTRAS_TESTING, _SERVE_TESTING +from flash.core.utilities.imports import ( + _EFFDET_AVAILABLE, + _ICEVISION_AVAILABLE, + _TOPIC_IMAGE_AVAILABLE, + _TOPIC_SERVE_AVAILABLE, +) from flash.image import ObjectDetector from tests.helpers.task_tester import TaskTester @@ -68,12 +73,13 @@ def __getitem__(self, idx): return sample +@pytest.mark.skipif(not _EFFDET_AVAILABLE, reason="effdet is not installed for testing") class TestObjectDetector(TaskTester): task = ObjectDetector task_kwargs = {"num_classes": 2} cli_command = "object_detection" - is_testing = _IMAGE_EXTRAS_TESTING - is_available = _IMAGE_AVAILABLE and _ICEVISION_AVAILABLE + is_testing = _TOPIC_IMAGE_AVAILABLE + is_available = _TOPIC_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE # TODO: Resolve JIT support traceable = False @@ -109,7 +115,7 @@ def example_test_sample(self): @pytest.mark.parametrize("head", ["retinanet"]) -@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_predict(tmpdir, head): model = ObjectDetector(num_classes=2, head=head, pretrained=False) ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10) @@ -126,11 +132,7 @@ def test_predict(tmpdir, head): ) trainer.fit(model, dl) - dl = model.process_predict_dataset( - ds, - 2, - input_transform=input_transform, - ) + dl = model.process_predict_dataset(ds, 2, input_transform=input_transform) predictions = trainer.predict(model, dl, output="preds") assert len(predictions[0][0]["bboxes"]) > 0 model.predict_kwargs = {"detection_threshold": 2} @@ -138,7 +140,7 @@ def test_predict(tmpdir, head): assert len(predictions[0][0]["bboxes"]) == 0 -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") @mock.patch("flash._IS_TESTING", True) def test_serve(): model = ObjectDetector(2) diff --git a/tests/image/detection/test_output.py b/tests/image/detection/test_output.py index 8d0b912caa..7ad27f60d6 100644 --- a/tests/image/detection/test_output.py +++ b/tests/image/detection/test_output.py @@ -3,11 +3,11 @@ import torch from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TOPIC_IMAGE_AVAILABLE from flash.image.detection.output import FiftyOneDetectionLabelsOutput -@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") class TestFiftyOneDetectionLabelsOutput: @staticmethod diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py index 9f2ec8e476..26e77dfa84 100644 --- a/tests/image/embedding/test_model.py +++ b/tests/image/embedding/test_model.py @@ -18,7 +18,7 @@ from torch import Tensor import flash -from flash.core.utilities.imports import _IMAGE_AVAILABLE, _IMAGE_TESTING, _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE +from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE, _TORCHVISION_AVAILABLE, _VISSL_AVAILABLE from flash.image import ImageClassificationData, ImageEmbedder from tests.helpers.task_tester import TaskTester @@ -33,8 +33,8 @@ class TestImageEmbedder(TaskTester): task_kwargs = dict( backbone="resnet18", ) - is_testing = _IMAGE_TESTING - is_available = _IMAGE_AVAILABLE + is_testing = _TOPIC_IMAGE_AVAILABLE + is_available = _TOPIC_IMAGE_AVAILABLE # TODO: Resolve JIT script issues scriptable = False @@ -49,7 +49,7 @@ def check_forward_output(self, output: Any): @pytest.mark.skipif(torch.cuda.device_count() > 1, reason="VISSL integration doesn't support multi-GPU") -@pytest.mark.skipif(not (_IMAGE_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") +@pytest.mark.skipif(not (_TOPIC_IMAGE_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") @pytest.mark.parametrize( "backbone, training_strategy, head, pretraining_transform, embedding_size", [ @@ -87,7 +87,7 @@ def test_vissl_training(backbone, training_strategy, head, pretraining_transform assert prediction.size(0) == embedding_size -@pytest.mark.skipif(not (_IMAGE_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") +@pytest.mark.skipif(not (_TOPIC_IMAGE_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") @pytest.mark.parametrize( "backbone, training_strategy, head, pretraining_transform, expected_exception", [ @@ -107,7 +107,7 @@ def test_vissl_training_with_wrong_arguments( ) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="torch vision not installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="torch vision not installed.") @pytest.mark.parametrize( "backbone, embedding_size", [ @@ -131,7 +131,7 @@ def test_only_embedding(backbone, embedding_size): assert prediction.size(0) == embedding_size -@pytest.mark.skipif(not _IMAGE_TESTING, reason="torch vision not installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="torch vision not installed.") def test_not_implemented_steps(): embedder = ImageEmbedder(backbone="resnet18") diff --git a/tests/image/instance_segmentation/__init__.py b/tests/image/instance_segm/__init__.py similarity index 100% rename from tests/image/instance_segmentation/__init__.py rename to tests/image/instance_segm/__init__.py diff --git a/tests/image/instance_segmentation/test_data.py b/tests/image/instance_segm/test_data.py similarity index 88% rename from tests/image/instance_segmentation/test_data.py rename to tests/image/instance_segm/test_data.py index e82562777a..1d00ce8e7e 100644 --- a/tests/image/instance_segmentation/test_data.py +++ b/tests/image/instance_segm/test_data.py @@ -16,13 +16,13 @@ import torch from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _IMAGE_EXTRAS_TESTING +from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE from flash.image.instance_segmentation import InstanceSegmentationData from flash.image.instance_segmentation.data import InstanceSegmentationOutputTransform from tests.image.detection.test_data import _create_synth_files_dataset, _create_synth_folders_dataset -@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_image_detector_data_from_files(tmpdir): predict_files = _create_synth_files_dataset(tmpdir) datamodule = InstanceSegmentationData.from_files( @@ -33,7 +33,7 @@ def test_image_detector_data_from_files(tmpdir): assert sample[DataKeys.INPUT].shape == (128, 128, 3) -@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_image_detector_data_from_folders(tmpdir): predict_folder = _create_synth_folders_dataset(tmpdir) datamodule = InstanceSegmentationData.from_folders( @@ -44,7 +44,7 @@ def test_image_detector_data_from_folders(tmpdir): assert sample[DataKeys.INPUT].shape == (128, 128, 3) -@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_instance_segmentation_output_transform(): sample = { DataKeys.INPUT: torch.rand(3, 224, 224), diff --git a/tests/image/instance_segmentation/test_model.py b/tests/image/instance_segm/test_model.py similarity index 90% rename from tests/image/instance_segmentation/test_model.py rename to tests/image/instance_segm/test_model.py index 5efd09e4b7..872a5e0e4d 100644 --- a/tests/image/instance_segmentation/test_model.py +++ b/tests/image/instance_segm/test_model.py @@ -22,11 +22,11 @@ from flash import Trainer from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE, _IMAGE_EXTRAS_TESTING +from flash.core.utilities.imports import _ICEDATA_AVAILABLE, _ICEVISION_AVAILABLE, _TOPIC_IMAGE_AVAILABLE from flash.image import InstanceSegmentation, InstanceSegmentationData from tests.helpers.task_tester import TaskTester -if _IMAGE_AVAILABLE: +if _TOPIC_IMAGE_AVAILABLE: from PIL import Image COCODataConfig = collections.namedtuple("COCODataConfig", "train_folder train_ann_file predict_folder") @@ -90,12 +90,14 @@ def coco_instances(tmpdir): return COCODataConfig(train_folder, train_ann_file, predict_folder) +@pytest.mark.skipif(not _ICEDATA_AVAILABLE, reason="icedata is not installed for testing") +@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="icevision is not installed for testing") class TestInstanceSegmentation(TaskTester): task = InstanceSegmentation task_kwargs = {"num_classes": 2} cli_command = "instance_segmentation" - is_testing = _IMAGE_EXTRAS_TESTING - is_available = _IMAGE_AVAILABLE and _ICEVISION_AVAILABLE + is_testing = _TOPIC_IMAGE_AVAILABLE + is_available = _TOPIC_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE # TODO: Resolve JIT support traceable = False @@ -131,7 +133,7 @@ def example_test_sample(self): return self.example_train_sample -@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.parametrize("backbone, head", [("resnet18_fpn", "mask_rcnn")]) def test_model(coco_instances, backbone, head): datamodule = InstanceSegmentationData.from_coco( diff --git a/tests/image/keypoint_detection/test_data.py b/tests/image/keypoint_detection/test_data.py index 5901191b53..7aefde8844 100644 --- a/tests/image/keypoint_detection/test_data.py +++ b/tests/image/keypoint_detection/test_data.py @@ -14,12 +14,12 @@ import pytest from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _IMAGE_EXTRAS_TESTING +from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE from flash.image.keypoint_detection import KeypointDetectionData from tests.image.detection.test_data import _create_synth_files_dataset, _create_synth_folders_dataset -@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_image_detector_data_from_files(tmpdir): predict_files = _create_synth_files_dataset(tmpdir) datamodule = KeypointDetectionData.from_files( @@ -30,7 +30,7 @@ def test_image_detector_data_from_files(tmpdir): assert sample[DataKeys.INPUT].shape == (128, 128, 3) -@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_image_detector_data_from_folders(tmpdir): predict_folder = _create_synth_folders_dataset(tmpdir) datamodule = KeypointDetectionData.from_folders( diff --git a/tests/image/keypoint_detection/test_model.py b/tests/image/keypoint_detection/test_model.py index 328d235890..be47ad5204 100644 --- a/tests/image/keypoint_detection/test_model.py +++ b/tests/image/keypoint_detection/test_model.py @@ -22,11 +22,11 @@ from flash import Trainer from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE, _IMAGE_EXTRAS_TESTING +from flash.core.utilities.imports import _ICEDATA_AVAILABLE, _ICEVISION_AVAILABLE, _TOPIC_IMAGE_AVAILABLE from flash.image import KeypointDetectionData, KeypointDetector from tests.helpers.task_tester import TaskTester -if _IMAGE_AVAILABLE: +if _TOPIC_IMAGE_AVAILABLE: from PIL import Image COCODataConfig = collections.namedtuple("COCODataConfig", "train_folder train_ann_file predict_folder") @@ -93,13 +93,15 @@ def coco_keypoints(tmpdir): return COCODataConfig(train_folder, train_ann_file, predict_folder) +@pytest.mark.skipif(not _ICEDATA_AVAILABLE, reason="icedata is not installed for testing") +@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="icevision is not installed for testing") class TestKeypointDetector(TaskTester): task = KeypointDetector task_args = (2,) task_kwargs = {"num_classes": 2} cli_command = "keypoint_detection" - is_testing = _IMAGE_EXTRAS_TESTING - is_available = _IMAGE_AVAILABLE and _ICEVISION_AVAILABLE + is_testing = _TOPIC_IMAGE_AVAILABLE + is_available = _TOPIC_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE # TODO: Resolve JIT support traceable = False @@ -138,7 +140,7 @@ def example_test_sample(self): return self.example_train_sample -@pytest.mark.skipif(not _IMAGE_EXTRAS_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.parametrize("backbone, head", [("resnet18_fpn", "keypoint_rcnn")]) def test_model(coco_keypoints, backbone, head): datamodule = KeypointDetectionData.from_coco( diff --git a/tests/image/segmentation/__init__.py b/tests/image/semantic_segm/__init__.py similarity index 100% rename from tests/image/segmentation/__init__.py rename to tests/image/semantic_segm/__init__.py diff --git a/tests/image/segmentation/test_backbones.py b/tests/image/semantic_segm/test_backbones.py similarity index 100% rename from tests/image/segmentation/test_backbones.py rename to tests/image/semantic_segm/test_backbones.py diff --git a/tests/image/segmentation/test_data.py b/tests/image/semantic_segm/test_data.py similarity index 94% rename from tests/image/segmentation/test_data.py rename to tests/image/semantic_segm/test_data.py index 81d6e55a32..c865285055 100644 --- a/tests/image/segmentation/test_data.py +++ b/tests/image/semantic_segm/test_data.py @@ -10,10 +10,9 @@ from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, - _IMAGE_AVAILABLE, - _IMAGE_TESTING, _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, + _TOPIC_IMAGE_AVAILABLE, ) from flash.image import SemanticSegmentation, SemanticSegmentationData @@ -53,13 +52,13 @@ def create_random_data(image_files: List[str], label_files: List[str], size: Tup class TestSemanticSegmentationData: @staticmethod - @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") + @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_smoke(): dm = SemanticSegmentationData(batch_size=1) assert dm is not None @staticmethod - @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") + @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_folders(tmpdir): tmp_dir = Path(tmpdir) @@ -121,7 +120,7 @@ def test_from_folders(tmpdir): assert labels.shape == (2, 128, 128) @staticmethod - @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") + @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_folders_different_extensions(tmpdir): tmp_dir = Path(tmpdir) @@ -183,7 +182,7 @@ def test_from_folders_different_extensions(tmpdir): assert labels.shape == (2, 128, 128) @staticmethod - @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") + @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_folders_error(tmpdir): tmp_dir = Path(tmpdir) @@ -218,7 +217,7 @@ def test_from_folders_error(tmpdir): ) @staticmethod - @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") + @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_files(tmpdir): tmp_dir = Path(tmpdir) @@ -277,7 +276,7 @@ def test_from_files(tmpdir): assert labels.shape == (2, 128, 128) @staticmethod - @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") + @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_files_warning(tmpdir): tmp_dir = Path(tmpdir) @@ -311,7 +310,7 @@ def test_from_files_warning(tmpdir): ) @staticmethod - @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") + @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") def test_from_fiftyone(tmpdir): tmp_dir = Path(tmpdir) @@ -381,7 +380,7 @@ def test_from_fiftyone(tmpdir): assert imgs.shape == (2, 3, 128, 128) @staticmethod - @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") + @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") def test_map_labels(tmpdir): tmp_dir = Path(tmpdir) diff --git a/tests/image/segmentation/test_heads.py b/tests/image/semantic_segm/test_heads.py similarity index 93% rename from tests/image/segmentation/test_heads.py rename to tests/image/semantic_segm/test_heads.py index a7c64ab41c..cb18b2bfd1 100644 --- a/tests/image/segmentation/test_heads.py +++ b/tests/image/semantic_segm/test_heads.py @@ -16,7 +16,7 @@ import pytest import torch -from flash.core.utilities.imports import _IMAGE_TESTING, _SEGMENTATION_MODELS_AVAILABLE +from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE, _TOPIC_IMAGE_AVAILABLE from flash.image.segmentation import SemanticSegmentation from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS @@ -43,7 +43,7 @@ def test_semantic_segmentation_heads_registry(head): assert res.shape[1] == 10 -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @unittest.mock.patch("flash.image.segmentation.heads.smp") def test_pretrained_weights(mock_smp): mock_smp.create_model = unittest.mock.MagicMock() diff --git a/tests/image/segmentation/test_model.py b/tests/image/semantic_segm/test_model.py similarity index 80% rename from tests/image/segmentation/test_model.py rename to tests/image/semantic_segm/test_model.py index 1d6e7d353c..03069563a5 100644 --- a/tests/image/segmentation/test_model.py +++ b/tests/image/semantic_segm/test_model.py @@ -21,7 +21,7 @@ from flash import Trainer from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _IMAGE_AVAILABLE, _IMAGE_TESTING, _SERVE_TESTING +from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE, _TOPIC_SERVE_AVAILABLE from flash.image import SemanticSegmentation from flash.image.segmentation.data import SemanticSegmentationData from tests.helpers.task_tester import TaskTester @@ -31,8 +31,8 @@ class TestSemanticSegmentation(TaskTester): task = SemanticSegmentation task_args = (2,) cli_command = "semantic_segmentation" - is_testing = _IMAGE_TESTING - is_available = _IMAGE_AVAILABLE + is_testing = _TOPIC_IMAGE_AVAILABLE + is_available = _TOPIC_IMAGE_AVAILABLE scriptable = False @property @@ -59,13 +59,13 @@ def example_test_sample(self): return self.example_train_sample -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_non_existent_backbone(): with pytest.raises(KeyError): SemanticSegmentation(2, "i am never going to implement this lol") -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_freeze(): model = SemanticSegmentation(2) model.freeze() @@ -73,7 +73,7 @@ def test_freeze(): assert p.requires_grad is False -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_unfreeze(): model = SemanticSegmentation(2) model.unfreeze() @@ -81,7 +81,7 @@ def test_unfreeze(): assert p.requires_grad is True -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_predict_tensor(): img = torch.rand(1, 3, 64, 64) model = SemanticSegmentation(2, backbone="mobilenetv3_large_100") @@ -93,7 +93,7 @@ def test_predict_tensor(): assert len(out[0][0][0]) == 64 -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_predict_numpy(): img = np.ones((1, 3, 64, 64)) model = SemanticSegmentation(2, backbone="mobilenetv3_large_100") @@ -105,7 +105,7 @@ def test_predict_numpy(): assert len(out[0][0][0]) == 64 -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") @mock.patch("flash._IS_TESTING", True) def test_serve(): model = SemanticSegmentation(2) @@ -113,6 +113,6 @@ def test_serve(): model.serve() -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_available_pretrained_weights(): assert SemanticSegmentation.available_pretrained_weights("resnet18") == ["imagenet", "ssl", "swsl"] diff --git a/tests/image/segmentation/test_output.py b/tests/image/semantic_segm/test_output.py similarity index 86% rename from tests/image/segmentation/test_output.py rename to tests/image/semantic_segm/test_output.py index 78224ed686..f431558e7b 100644 --- a/tests/image/segmentation/test_output.py +++ b/tests/image/semantic_segm/test_output.py @@ -15,12 +15,12 @@ import torch from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _IMAGE_TESTING +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TOPIC_IMAGE_AVAILABLE from flash.image.segmentation.output import FiftyOneSegmentationLabelsOutput, SegmentationLabelsOutput class TestSemanticSegmentationLabelsOutput: - @pytest.mark.skipif(not _IMAGE_TESTING, "image libraries aren't installed.") + @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, "image libraries aren't installed.") @staticmethod def test_smoke(): serial = SegmentationLabelsOutput() @@ -28,7 +28,7 @@ def test_smoke(): assert serial.labels_map is None assert serial.visualize is False - @pytest.mark.skipif(not _IMAGE_TESTING, "image libraries aren't installed.") + @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, "image libraries aren't installed.") @staticmethod def test_exception(): serial = SegmentationLabelsOutput() @@ -41,7 +41,7 @@ def test_exception(): sample = torch.zeros(2, 3) serial.transform(sample) - @pytest.mark.skipif(not _IMAGE_TESTING, "image libraries aren't installed.") + @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, "image libraries aren't installed.") @staticmethod def test_serialize(): serial = SegmentationLabelsOutput() @@ -54,7 +54,7 @@ def test_serialize(): assert torch.tensor(classes)[1, 2] == 1 assert torch.tensor(classes)[0, 1] == 3 - @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") + @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") @staticmethod def test_serialize_fiftyone(): diff --git a/tests/image/style_transfer/test_model.py b/tests/image/style_transfer/test_model.py index 37b27580e3..51fea84407 100644 --- a/tests/image/style_transfer/test_model.py +++ b/tests/image/style_transfer/test_model.py @@ -19,7 +19,7 @@ from torch import Tensor from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _IMAGE_AVAILABLE, _IMAGE_TESTING +from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE from flash.image.style_transfer import StyleTransfer from tests.helpers.task_tester import TaskTester @@ -28,8 +28,8 @@ class TestStyleTransfer(TaskTester): task = StyleTransfer cli_command = "style_transfer" - is_testing = _IMAGE_TESTING - is_available = _IMAGE_AVAILABLE + is_testing = _TOPIC_IMAGE_AVAILABLE + is_available = _TOPIC_IMAGE_AVAILABLE # TODO: loss_fn and perceptual_loss can't be jitted scriptable = False @@ -48,7 +48,7 @@ def example_train_sample(self): return {DataKeys.INPUT: torch.rand(3, 224, 224)} -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.xfail(URLError, reason="Connection timed out for download.pystiche.org") def test_style_transfer_task(): model = StyleTransfer( diff --git a/tests/image/test_backbones.py b/tests/image/test_backbones.py index 3d9703a2ec..842027b571 100644 --- a/tests/image/test_backbones.py +++ b/tests/image/test_backbones.py @@ -15,7 +15,7 @@ import pytest -from flash.core.utilities.imports import _IMAGE_TESTING +from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE from flash.core.utilities.url_error import catch_url_error from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES @@ -23,9 +23,11 @@ @pytest.mark.parametrize( ["backbone", "expected_num_features"], [ - pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")), - pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No timm")), - pytest.param("mobilenet_v2", 1280, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")), + pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="No torchvision")), + pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="No timm")), + pytest.param( + "mobilenet_v2", 1280, marks=pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="No torchvision") + ), ], ) def test_image_classifier_backbones_registry(backbone, expected_num_features): @@ -42,9 +44,11 @@ def test_image_classifier_backbones_registry(backbone, expected_num_features): "resnet50", "supervised", 2048, - marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision"), + marks=pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="No torchvision"), + ), + pytest.param( + "resnet50", "simclr", 2048, marks=pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="No torchvision") ), - pytest.param("resnet50", "simclr", 2048, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")), ], ) def test_pretrained_weights_registry(backbone, pretrained, expected_num_features): diff --git a/tests/pointcloud/detection/test_data.py b/tests/pointcloud/detection/test_data.py index 4e5edd5d99..4515d17a50 100644 --- a/tests/pointcloud/detection/test_data.py +++ b/tests/pointcloud/detection/test_data.py @@ -20,14 +20,14 @@ from flash import Trainer from flash.core.data.io.input import DataKeys from flash.core.data.utils import download_data -from flash.core.utilities.imports import _POINTCLOUD_TESTING +from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE from flash.pointcloud.detection import PointCloudObjectDetector, PointCloudObjectDetectorData -if _POINTCLOUD_TESTING: +if _TOPIC_POINTCLOUD_AVAILABLE: from flash.pointcloud.detection.open3d_ml.backbones import ObjectDetectBatchCollator -@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +@pytest.mark.skipif(not _TOPIC_POINTCLOUD_AVAILABLE, reason="pointcloud libraries aren't installed") def test_pointcloud_object_detection_data(tmpdir): seed_everything(52) diff --git a/tests/pointcloud/detection/test_model.py b/tests/pointcloud/detection/test_model.py index 556542c3aa..aa088928e5 100644 --- a/tests/pointcloud/detection/test_model.py +++ b/tests/pointcloud/detection/test_model.py @@ -13,7 +13,7 @@ # limitations under the License. import pytest -from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE, _POINTCLOUD_TESTING +from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE from flash.pointcloud.detection import PointCloudObjectDetector from tests.helpers.task_tester import TaskTester @@ -22,11 +22,11 @@ class TestPointCloudObjectDetector(TaskTester): task = PointCloudObjectDetector task_args = (2,) cli_command = "pointcloud_detection" - is_testing = _POINTCLOUD_TESTING - is_available = _POINTCLOUD_AVAILABLE + is_testing = _TOPIC_POINTCLOUD_AVAILABLE + is_available = _TOPIC_POINTCLOUD_AVAILABLE -@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +@pytest.mark.skipif(not _TOPIC_POINTCLOUD_AVAILABLE, reason="pointcloud libraries aren't installed") def test_backbones(): backbones = PointCloudObjectDetector.available_backbones() assert backbones == ["pointpillars", "pointpillars_kitti"] diff --git a/tests/pointcloud/segmentation/test_data.py b/tests/pointcloud/segmentation/test_data.py index 3f7f953298..3e2314ecaa 100644 --- a/tests/pointcloud/segmentation/test_data.py +++ b/tests/pointcloud/segmentation/test_data.py @@ -20,11 +20,11 @@ from flash import Trainer from flash.core.data.io.input import DataKeys from flash.core.data.utils import download_data -from flash.core.utilities.imports import _POINTCLOUD_TESTING +from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE from flash.pointcloud.segmentation import PointCloudSegmentation, PointCloudSegmentationData -@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +@pytest.mark.skipif(not _TOPIC_POINTCLOUD_AVAILABLE, reason="pointcloud libraries aren't installed") def test_pointcloud_segmentation_data(tmpdir): seed_everything(52) diff --git a/tests/pointcloud/segmentation/test_datasets.py b/tests/pointcloud/segmentation/test_datasets.py index ac98b2ec5b..6c09f9da95 100644 --- a/tests/pointcloud/segmentation/test_datasets.py +++ b/tests/pointcloud/segmentation/test_datasets.py @@ -15,11 +15,11 @@ import pytest -from flash.core.utilities.imports import _POINTCLOUD_TESTING +from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE from flash.pointcloud.segmentation.datasets import LyftDataset, SemanticKITTIDataset -@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +@pytest.mark.skipif(not _TOPIC_POINTCLOUD_AVAILABLE, reason="pointcloud libraries aren't installed") @patch("flash.pointcloud.segmentation.datasets.os.system") def test_datasets(mock_system): LyftDataset("data") diff --git a/tests/pointcloud/segmentation/test_model.py b/tests/pointcloud/segmentation/test_model.py index 2ec1163acb..60456475f3 100644 --- a/tests/pointcloud/segmentation/test_model.py +++ b/tests/pointcloud/segmentation/test_model.py @@ -14,7 +14,7 @@ import pytest import torch -from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE, _POINTCLOUD_TESTING +from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE from flash.pointcloud.segmentation import PointCloudSegmentation from tests.helpers.task_tester import TaskTester @@ -23,17 +23,17 @@ class TestPointCloudSegmentation(TaskTester): task = PointCloudSegmentation task_args = (2,) cli_command = "pointcloud_segmentation" - is_testing = _POINTCLOUD_TESTING - is_available = _POINTCLOUD_AVAILABLE + is_testing = _TOPIC_POINTCLOUD_AVAILABLE + is_available = _TOPIC_POINTCLOUD_AVAILABLE -@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +@pytest.mark.skipif(not _TOPIC_POINTCLOUD_AVAILABLE, reason="pointcloud libraries aren't installed") def test_backbones(): backbones = PointCloudSegmentation.available_backbones() assert backbones == ["randlanet", "randlanet_s3dis", "randlanet_semantic_kitti", "randlanet_toronto3d"] -@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +@pytest.mark.skipif(not _TOPIC_POINTCLOUD_AVAILABLE, reason="pointcloud libraries aren't installed") @pytest.mark.parametrize( "backbone", [ diff --git a/tests/serve/__init__.py b/tests/serve/__init__.py new file mode 100644 index 0000000000..81a3bbadf5 --- /dev/null +++ b/tests/serve/__init__.py @@ -0,0 +1 @@ +"""This is just placeholder to have parity with domains/topics.""" diff --git a/tests/tabular/classification/test_data.py b/tests/tabular/classification/test_data.py index 68413375c2..6776ad35b7 100644 --- a/tests/tabular/classification/test_data.py +++ b/tests/tabular/classification/test_data.py @@ -18,9 +18,9 @@ import pytest from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _TABULAR_TESTING +from flash.core.utilities.imports import _TOPIC_TABULAR_AVAILABLE -if _TABULAR_TESTING: +if _TOPIC_TABULAR_AVAILABLE: import pandas as pd from flash.tabular import TabularClassificationData @@ -59,7 +59,7 @@ TEST_DF_2 = pd.DataFrame(data=TEST_DICT_2) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular dependencies are required") def test_categorize(): codes = _generate_codes(TEST_DF_1, ["category"]) assert codes == {"category": ["a", "b", "c"]} @@ -71,7 +71,7 @@ def test_categorize(): assert list(df["category"]) == [0, 0, 0] -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular dependencies are required") def test_normalize(): num_input = ["scalar_a", "scalar_b"] mean, std = _compute_normalization(TEST_DF_1, num_input) @@ -79,7 +79,7 @@ def test_normalize(): assert np.allclose(df[num_input].mean(), 0.0) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular dependencies are required") def test_normalize_large_array_dtype_fp16(): # See: https://github.com/Lightning-AI/lightning-flash/pull/1359 for the motivation behind this test arr = np.linspace(0, 10000, 10000, dtype=np.float16) @@ -90,7 +90,7 @@ def test_normalize_large_array_dtype_fp16(): assert np.allclose(df[col_name].mean(), 0.0) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular dependencies are required") def test_embedding_sizes(): self = Mock() @@ -110,7 +110,7 @@ def test_embedding_sizes(): assert es == [(100_000, 17), (1_000_000, 31)] -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular dependencies are required") def test_categorical_target(tmpdir): train_data_frame = TEST_DF_1.copy() val_data_frame = TEST_DF_2.copy() @@ -138,7 +138,7 @@ def test_categorical_target(tmpdir): assert target.shape == (1,) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular dependencies are required") def test_from_data_frame(tmpdir): train_data_frame = TEST_DF_1.copy() val_data_frame = TEST_DF_2.copy() @@ -162,7 +162,7 @@ def test_from_data_frame(tmpdir): assert target.shape == (1,) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular dependencies are required") def test_from_csv(tmpdir): train_csv = Path(tmpdir) / "train.csv" val_csv = test_csv = Path(tmpdir) / "valid.csv" @@ -189,7 +189,7 @@ def test_from_csv(tmpdir): assert target.shape == (1,) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular dependencies are required") def test_from_dicts(): dm = TabularClassificationData.from_dicts( categorical_fields=["category"], @@ -210,7 +210,7 @@ def test_from_dicts(): assert target.shape == (1,) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular dependencies are required") def test_from_lists(): dm = TabularClassificationData.from_lists( categorical_fields=["category"], @@ -231,7 +231,7 @@ def test_from_lists(): assert target.shape == (1,) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular dependencies are required") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular dependencies are required") def test_empty_inputs(): train_data_frame = TEST_DF_1.copy() with pytest.raises(RuntimeError): diff --git a/tests/tabular/classification/test_data_model_integration.py b/tests/tabular/classification/test_data_model_integration.py index 3294b8e22d..d9c86aa081 100644 --- a/tests/tabular/classification/test_data_model_integration.py +++ b/tests/tabular/classification/test_data_model_integration.py @@ -14,10 +14,10 @@ import pytest import pytorch_lightning as pl -from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TABULAR_TESTING +from flash.core.utilities.imports import _TOPIC_TABULAR_AVAILABLE from flash.tabular import TabularClassificationData, TabularClassifier -if _TABULAR_AVAILABLE: +if _TOPIC_TABULAR_AVAILABLE: import pandas as pd TEST_DF_1 = pd.DataFrame( @@ -30,7 +30,7 @@ ) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") @pytest.mark.parametrize( "backbone,fields", [ diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index 674772c9e4..5b6347dc70 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -21,7 +21,7 @@ import flash from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _SERVE_TESTING, _TABULAR_AVAILABLE, _TABULAR_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE, _TOPIC_TABULAR_AVAILABLE from flash.tabular.classification.data import TabularClassificationData from flash.tabular.classification.model import TabularClassifier from tests.helpers.task_tester import StaticDataset, TaskTester @@ -38,8 +38,8 @@ class TestTabularClassifier(TaskTester): "backbone": "tabnet", } cli_command = "tabular_classification" - is_testing = _TABULAR_TESTING - is_available = _TABULAR_AVAILABLE + is_testing = _TOPIC_TABULAR_AVAILABLE + is_available = _TOPIC_TABULAR_AVAILABLE # TODO: Resolve JIT issues scriptable = False @@ -140,7 +140,7 @@ def test_init_train_no_cat(self, backbone, tmpdir): trainer.fit(model, model.process_train_dataset(dataset, batch_size=4)) -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") @pytest.mark.parametrize( "backbone", ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"] ) diff --git a/tests/tabular/forecasting/test_data.py b/tests/tabular/forecasting/test_data.py index 79683a65fa..ad640da9b8 100644 --- a/tests/tabular/forecasting/test_data.py +++ b/tests/tabular/forecasting/test_data.py @@ -15,11 +15,11 @@ import pytest -from flash.core.utilities.imports import _TABULAR_TESTING +from flash.core.utilities.imports import _TOPIC_TABULAR_AVAILABLE from flash.tabular.forecasting import TabularForecastingData -@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="Tabular libraries aren't installed.") @patch("flash.tabular.forecasting.input.TimeSeriesDataSet") def test_from_data_frame_time_series_data_set_single_call(patch_time_series_data_set): """Tests that ``TabularForecastingData.from_data_frame`` calls ``TimeSeriesDataSet`` with the expected parameters @@ -48,7 +48,7 @@ def test_from_data_frame_time_series_data_set_single_call(patch_time_series_data ) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="Tabular libraries aren't installed.") @patch("flash.tabular.forecasting.input.TimeSeriesDataSet") def test_from_data_frame_time_series_data_set_multi_call(patch_time_series_data_set): """Tests that ``TabularForecastingData.from_data_frame`` calls ``TimeSeriesDataSet`` with the expected parameters @@ -82,7 +82,7 @@ def test_from_data_frame_time_series_data_set_multi_call(patch_time_series_data_ ) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="Tabular libraries aren't installed.") def test_from_data_frame_misconfiguration(): """Tests that a ``ValueError`` is raised when ``TabularForecastingData`` is constructed without parameters.""" with pytest.raises(ValueError, match="evaluation or inference requires parameters"): diff --git a/tests/tabular/forecasting/test_model.py b/tests/tabular/forecasting/test_model.py index a786b8196b..ca0fe2e41a 100644 --- a/tests/tabular/forecasting/test_model.py +++ b/tests/tabular/forecasting/test_model.py @@ -19,11 +19,11 @@ import flash from flash import DataKeys -from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TABULAR_TESTING +from flash.core.utilities.imports import _TOPIC_TABULAR_AVAILABLE from flash.tabular.forecasting import TabularForecaster from tests.helpers.task_tester import StaticDataset, TaskTester -if _TABULAR_AVAILABLE: +if _TOPIC_TABULAR_AVAILABLE: from pytorch_forecasting.data import EncoderNormalizer, NaNLabelEncoder else: EncoderNormalizer = object @@ -72,8 +72,8 @@ class TestTabularForecaster(TaskTester): "backbone_kwargs": {"widths": [32, 512], "backcast_loss_ratio": 0.1}, } cli_command = "tabular_forecasting" - is_testing = _TABULAR_TESTING - is_available = _TABULAR_AVAILABLE + is_testing = _TOPIC_TABULAR_AVAILABLE + is_available = _TOPIC_TABULAR_AVAILABLE # # TODO: Resolve JIT issues scriptable = False diff --git a/tests/tabular/regression/test_data_model_integration.py b/tests/tabular/regression/test_data_model_integration.py index 21e85356c7..3cfeb6fdb2 100644 --- a/tests/tabular/regression/test_data_model_integration.py +++ b/tests/tabular/regression/test_data_model_integration.py @@ -14,10 +14,10 @@ import pytest import pytorch_lightning as pl -from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TABULAR_TESTING +from flash.core.utilities.imports import _TOPIC_TABULAR_AVAILABLE from flash.tabular import TabularRegressionData, TabularRegressor -if _TABULAR_AVAILABLE: +if _TOPIC_TABULAR_AVAILABLE: import pandas as pd TEST_DICT = { @@ -39,7 +39,7 @@ TEST_DF = pd.DataFrame(data=TEST_DICT) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") @pytest.mark.parametrize( "backbone,fields", [ @@ -72,7 +72,7 @@ def test_regression_data_frame(backbone, fields, tmpdir): trainer.fit(model, data) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") @pytest.mark.parametrize( "backbone,fields", [ @@ -102,7 +102,7 @@ def test_regression_dicts(backbone, fields, tmpdir): trainer.fit(model, data) -@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") @pytest.mark.parametrize( "backbone,fields", [ diff --git a/tests/tabular/regression/test_model.py b/tests/tabular/regression/test_model.py index 6dbfe32a55..39789d7681 100644 --- a/tests/tabular/regression/test_model.py +++ b/tests/tabular/regression/test_model.py @@ -21,7 +21,7 @@ import flash from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _SERVE_TESTING, _TABULAR_AVAILABLE, _TABULAR_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE, _TOPIC_TABULAR_AVAILABLE from flash.tabular import TabularRegressionData, TabularRegressor from tests.helpers.task_tester import StaticDataset, TaskTester @@ -36,8 +36,8 @@ class TestTabularRegressor(TaskTester): "backbone": "tabnet", } cli_command = "tabular_regression" - is_testing = _TABULAR_TESTING - is_available = _TABULAR_AVAILABLE + is_testing = _TOPIC_TABULAR_AVAILABLE + is_available = _TOPIC_TABULAR_AVAILABLE # TODO: Resolve JIT issues scriptable = False @@ -138,7 +138,7 @@ def test_init_train_no_cat(self, backbone, tmpdir): trainer.fit(model, model.process_train_dataset(dataset, batch_size=4)) -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") @pytest.mark.parametrize( "backbone", ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"] ) diff --git a/tests/template/classification/test_data.py b/tests/template/classification/test_data.py index d83bac5756..fc2963e8d6 100644 --- a/tests/template/classification/test_data.py +++ b/tests/template/classification/test_data.py @@ -15,14 +15,14 @@ import pytest from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _CORE_TESTING, _SKLEARN_AVAILABLE +from flash.core.utilities.imports import _SKLEARN_AVAILABLE, _TOPIC_CORE_AVAILABLE from flash.template.classification.data import TemplateData if _SKLEARN_AVAILABLE: from sklearn import datasets -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") class TestTemplateData: """Tests ``TemplateData``.""" diff --git a/tests/template/classification/test_model.py b/tests/template/classification/test_model.py index cfcb586a58..41f6d45e09 100644 --- a/tests/template/classification/test_model.py +++ b/tests/template/classification/test_model.py @@ -20,7 +20,7 @@ from flash import Trainer from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _CORE_TESTING, _SKLEARN_AVAILABLE +from flash.core.utilities.imports import _SKLEARN_AVAILABLE, _TOPIC_CORE_AVAILABLE from flash.template import TemplateSKLearnClassifier from flash.template.classification.data import TemplateData @@ -49,14 +49,14 @@ def __len__(self) -> int: # ============================== -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_smoke(): """A simple test that the class can be instantiated.""" model = TemplateSKLearnClassifier(num_features=1, num_classes=1) assert model is not None -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("num_classes", [4, 256]) @pytest.mark.parametrize("shape", [(1, 3), (2, 128)]) def test_forward(num_classes, shape): @@ -73,7 +73,7 @@ def test_forward(num_classes, shape): assert out.shape == (shape[0], num_classes) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_train(tmpdir): """Tests that the model can be trained on our ``DummyDataset``.""" model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes) @@ -82,7 +82,7 @@ def test_train(tmpdir): trainer.fit(model, train_dl) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_val(tmpdir): """Tests that the model can be validated on our ``DummyDataset``.""" model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes) @@ -91,7 +91,7 @@ def test_val(tmpdir): trainer.validate(model, val_dl) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_test(tmpdir): """Tests that the model can be tested on our ``DummyDataset``.""" model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes) @@ -100,7 +100,7 @@ def test_test(tmpdir): trainer.test(model, test_dl) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_predict_numpy(): """Tests that we can generate predictions from a numpy array.""" row = np.random.rand(1, DummyDataset.num_features) @@ -111,7 +111,7 @@ def test_predict_numpy(): assert isinstance(out[0][0], int) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_predict_sklearn(): """Tests that we can generate predictions from a scikit-learn ``Bunch``.""" bunch = datasets.load_iris() @@ -122,7 +122,7 @@ def test_predict_sklearn(): assert isinstance(out[0][0], int) -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") +@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 16),))]) def test_jit(tmpdir, jitter, args): path = os.path.join(tmpdir, "testing_model.pt") diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py index 1ba45499c9..073c97205b 100644 --- a/tests/text/classification/test_data.py +++ b/tests/text/classification/test_data.py @@ -18,10 +18,10 @@ import pytest from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE from flash.text import TextClassificationData -if _TEXT_AVAILABLE: +if _TOPIC_TEXT_AVAILABLE: from datasets import Dataset TEST_CSV_DATA = """sentence,label @@ -117,7 +117,7 @@ def parquet_data(tmpdir, multilabel: bool): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_csv(tmpdir): csv_path = csv_data(tmpdir, multilabel=False) dm = TextClassificationData.from_csv( @@ -147,7 +147,7 @@ def test_from_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_csv_multilabel(tmpdir): csv_path = csv_data(tmpdir, multilabel=True) dm = TextClassificationData.from_csv( @@ -179,7 +179,7 @@ def test_from_csv_multilabel(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_json(tmpdir): json_path = json_data(tmpdir, multilabel=False) dm = TextClassificationData.from_json( @@ -209,7 +209,7 @@ def test_from_json(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_json_multilabel(tmpdir): json_path = json_data(tmpdir, multilabel=True) dm = TextClassificationData.from_json( @@ -241,7 +241,7 @@ def test_from_json_multilabel(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_json_with_field(tmpdir): json_path = json_data_with_field(tmpdir, multilabel=False) dm = TextClassificationData.from_json( @@ -272,7 +272,7 @@ def test_from_json_with_field(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_json_with_field_multilabel(tmpdir): json_path = json_data_with_field(tmpdir, multilabel=True) dm = TextClassificationData.from_json( @@ -305,7 +305,7 @@ def test_from_json_with_field_multilabel(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_parquet(tmpdir): parquet_path = parquet_data(tmpdir, False) dm = TextClassificationData.from_parquet( @@ -335,7 +335,7 @@ def test_from_parquet(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_parquet_multilabel(tmpdir): parquet_path = parquet_data(tmpdir, True) dm = TextClassificationData.from_parquet( @@ -367,7 +367,7 @@ def test_from_parquet_multilabel(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_data_frame(): dm = TextClassificationData.from_data_frame( "sentence", @@ -396,7 +396,7 @@ def test_from_data_frame(): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_data_frame_multilabel(): dm = TextClassificationData.from_data_frame( "sentence", @@ -427,7 +427,7 @@ def test_from_data_frame_multilabel(): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_hf_datasets(): TEST_HF_DATASET_DATA = Dataset.from_pandas(TEST_DATA_FRAME_DATA) dm = TextClassificationData.from_hf_datasets( @@ -457,7 +457,7 @@ def test_from_hf_datasets(): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_hf_datasets_multilabel(): TEST_HF_DATASET_DATA_MULTILABEL = Dataset.from_pandas(TEST_DATA_FRAME_DATA_MULTILABEL) dm = TextClassificationData.from_hf_datasets( @@ -489,7 +489,7 @@ def test_from_hf_datasets_multilabel(): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_lists(): dm = TextClassificationData.from_lists( train_data=TEST_LIST_DATA, @@ -519,7 +519,7 @@ def test_from_lists(): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_lists_multilabel(): dm = TextClassificationData.from_lists( train_data=TEST_LIST_DATA, @@ -550,7 +550,7 @@ def test_from_lists_multilabel(): assert isinstance(batch[DataKeys.INPUT][0], str) -@pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.") +@pytest.mark.skipif(_TOPIC_TEXT_AVAILABLE, reason="text libraries are installed.") def test_text_module_not_found_error(): with pytest.raises(ModuleNotFoundError, match="[text]"): TextClassificationData.from_json("sentence", "lab", train_file="", batch_size=1) diff --git a/tests/text/classification/test_data_model_integration.py b/tests/text/classification/test_data_model_integration.py index e9b7827ea3..b4ee5d16c8 100644 --- a/tests/text/classification/test_data_model_integration.py +++ b/tests/text/classification/test_data_model_integration.py @@ -17,7 +17,7 @@ import pytest from flash.core.trainer import Trainer -from flash.core.utilities.imports import _TEXT_TESTING +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE from flash.text import TextClassificationData, TextClassifier TEST_BACKBONE = "prajjwal1/bert-tiny" # tiny model for testing @@ -36,7 +36,7 @@ def csv_data(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_classification(tmpdir): csv_path = csv_data(tmpdir) diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py index 05bf82b54e..26c71b7778 100644 --- a/tests/text/classification/test_model.py +++ b/tests/text/classification/test_model.py @@ -20,7 +20,7 @@ import flash from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _SERVE_TESTING, _TEXT_AVAILABLE, _TEXT_TESTING, _TORCH_ORT_AVAILABLE +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE, _TOPIC_TEXT_AVAILABLE, _TORCH_ORT_AVAILABLE from flash.text import TextClassifier from flash.text.ort_callback import ORTCallback from tests.helpers.boring_model import BoringModel @@ -34,8 +34,8 @@ class TestTextClassifier(TaskTester): task_args = (2,) task_kwargs = {"backbone": TEST_BACKBONE} cli_command = "text_classification" - is_testing = _TEXT_TESTING - is_available = _TEXT_AVAILABLE + is_testing = _TOPIC_TEXT_AVAILABLE + is_available = _TOPIC_TEXT_AVAILABLE scriptable = False @@ -106,7 +106,7 @@ def test_ort_callback_fails_no_model(self, tmpdir): trainer.fit(model, model.process_train_dataset(dataset, batch_size=4)) -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") @mock.patch("flash._IS_TESTING", True) def test_serve(): model = TextClassifier(2, backbone=TEST_BACKBONE) diff --git a/tests/text/embedding/test_model.py b/tests/text/embedding/test_model.py index f295f97e17..87a0ce513f 100644 --- a/tests/text/embedding/test_model.py +++ b/tests/text/embedding/test_model.py @@ -19,7 +19,7 @@ from torch import Tensor import flash -from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE from flash.text import TextClassificationData, TextEmbedder from tests.helpers.task_tester import TaskTester @@ -39,8 +39,8 @@ class TestTextEmbedder(TaskTester): task = TextEmbedder task_kwargs = {"backbone": TEST_BACKBONE} - is_testing = _TEXT_TESTING - is_available = _TEXT_AVAILABLE + is_testing = _TOPIC_TEXT_AVAILABLE + is_available = _TOPIC_TEXT_AVAILABLE scriptable = False @@ -54,7 +54,7 @@ def check_forward_output(self, output: Any): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_predict(tmpdir): datamodule = TextClassificationData.from_lists(predict_data=predict_data, batch_size=4) model = TextEmbedder(backbone=TEST_BACKBONE) diff --git a/tests/text/question_answering/test_data.py b/tests/text/question_answering/test_data.py index afe80479d4..90ed1b3999 100644 --- a/tests/text/question_answering/test_data.py +++ b/tests/text/question_answering/test_data.py @@ -18,7 +18,7 @@ import pandas as pd import pytest -from flash.core.utilities.imports import _TEXT_TESTING +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE from flash.text import QuestionAnsweringData TEST_CSV_DATA = { @@ -100,7 +100,7 @@ def json_data_with_field(tmpdir, data): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) dm = QuestionAnsweringData.from_csv( @@ -117,7 +117,7 @@ def test_from_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_files(tmpdir): csv_path = csv_data(tmpdir) dm = QuestionAnsweringData.from_csv( @@ -141,7 +141,7 @@ def test_from_files(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_json(tmpdir): json_path = json_data(tmpdir, TEST_JSON_DATA) dm = QuestionAnsweringData.from_json( @@ -158,7 +158,7 @@ def test_from_json(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_json_with_field(tmpdir): json_path = json_data_with_field(tmpdir, TEST_JSON_DATA) dm = QuestionAnsweringData.from_json( @@ -176,7 +176,7 @@ def test_from_json_with_field(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_wrong_keys_and_types(tmpdir): TEST_CSV_DATA.pop("answer_text") with pytest.raises(KeyError): diff --git a/tests/text/question_answering/test_model.py b/tests/text/question_answering/test_model.py index 66446d736e..144fa1163c 100644 --- a/tests/text/question_answering/test_model.py +++ b/tests/text/question_answering/test_model.py @@ -18,7 +18,7 @@ import torch from torch import Tensor -from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE from flash.text import QuestionAnsweringTask from tests.helpers.task_tester import TaskTester @@ -29,8 +29,8 @@ class TestQuestionAnsweringTask(TaskTester): task = QuestionAnsweringTask task_kwargs = {"backbone": TEST_BACKBONE} cli_command = "question_answering" - is_testing = _TEXT_TESTING - is_available = _TEXT_AVAILABLE + is_testing = _TOPIC_TEXT_AVAILABLE + is_available = _TOPIC_TEXT_AVAILABLE scriptable = False traceable = False @@ -66,7 +66,7 @@ def example_test_sample(self): return self.example_train_sample -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_modules_to_freeze(): model = QuestionAnsweringTask(backbone=TEST_BACKBONE) assert model.modules_to_freeze() is model.model.distilbert diff --git a/tests/text/seq2seq/summarization/test_data.py b/tests/text/seq2seq/summarization/test_data.py index 4644453778..9241acd395 100644 --- a/tests/text/seq2seq/summarization/test_data.py +++ b/tests/text/seq2seq/summarization/test_data.py @@ -17,7 +17,7 @@ import pytest from flash import DataKeys -from flash.core.utilities.imports import _TEXT_TESTING +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE from flash.text import SummarizationData TEST_CSV_DATA = """input,target @@ -58,7 +58,7 @@ def json_data_with_field(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) dm = SummarizationData.from_csv("input", "target", train_file=csv_path, batch_size=1) @@ -68,7 +68,7 @@ def test_from_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_files(tmpdir): csv_path = csv_data(tmpdir) dm = SummarizationData.from_csv( @@ -89,7 +89,7 @@ def test_from_files(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_json(tmpdir): json_path = json_data(tmpdir) dm = SummarizationData.from_json( @@ -104,7 +104,7 @@ def test_from_json(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_json_with_field(tmpdir): json_path = json_data_with_field(tmpdir) dm = SummarizationData.from_json( diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py index dddca4b269..bf9c6f4ede 100644 --- a/tests/text/seq2seq/summarization/test_model.py +++ b/tests/text/seq2seq/summarization/test_model.py @@ -19,7 +19,7 @@ from torch import Tensor from flash import DataKeys -from flash.core.utilities.imports import _SERVE_TESTING, _TEXT_AVAILABLE, _TEXT_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE, _TOPIC_TEXT_AVAILABLE from flash.text import SummarizationTask from tests.helpers.task_tester import TaskTester @@ -33,8 +33,8 @@ class TestSummarizationTask(TaskTester): "tokenizer_kwargs": {"src_lang": "en_XX", "tgt_lang": "en_XX"}, } cli_command = "summarization" - is_testing = _TEXT_TESTING - is_available = _TEXT_AVAILABLE + is_testing = _TOPIC_TEXT_AVAILABLE + is_available = _TOPIC_TEXT_AVAILABLE scriptable = False @@ -61,7 +61,7 @@ def example_test_sample(self): return self.example_train_sample -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") @mock.patch("flash._IS_TESTING", True) def test_serve(): model = SummarizationTask(TEST_BACKBONE) diff --git a/tests/text/seq2seq/translation/test_data.py b/tests/text/seq2seq/translation/test_data.py index 1a361c509b..45663c9198 100644 --- a/tests/text/seq2seq/translation/test_data.py +++ b/tests/text/seq2seq/translation/test_data.py @@ -17,7 +17,7 @@ import pytest from flash import DataKeys -from flash.core.utilities.imports import _TEXT_TESTING +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE from flash.text import TranslationData TEST_CSV_DATA = """input,target @@ -58,7 +58,7 @@ def json_data_with_field(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) dm = TranslationData.from_csv( @@ -73,7 +73,7 @@ def test_from_csv(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_files(tmpdir): csv_path = csv_data(tmpdir) dm = TranslationData.from_csv( @@ -94,7 +94,7 @@ def test_from_files(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_json(tmpdir): json_path = json_data(tmpdir) dm = TranslationData.from_json( @@ -109,7 +109,7 @@ def test_from_json(tmpdir): @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed.") def test_from_json_with_field(tmpdir): json_path = json_data_with_field(tmpdir) dm = TranslationData.from_json( diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py index b6ef8cbe42..b26e767d9e 100644 --- a/tests/text/seq2seq/translation/test_model.py +++ b/tests/text/seq2seq/translation/test_model.py @@ -19,7 +19,7 @@ from torch import Tensor from flash import DataKeys -from flash.core.utilities.imports import _SERVE_TESTING, _TEXT_AVAILABLE, _TEXT_TESTING +from flash.core.utilities.imports import _TOPIC_SERVE_AVAILABLE, _TOPIC_TEXT_AVAILABLE from flash.text import TranslationTask from tests.helpers.task_tester import TaskTester @@ -33,8 +33,8 @@ class TestTranslationTask(TaskTester): "tokenizer_kwargs": {"src_lang": "en_XX", "tgt_lang": "ro_RO"}, } cli_command = "translation" - is_testing = _TEXT_TESTING - is_available = _TEXT_AVAILABLE + is_testing = _TOPIC_TEXT_AVAILABLE + is_available = _TOPIC_TEXT_AVAILABLE scriptable = False @@ -61,7 +61,7 @@ def example_test_sample(self): return self.example_train_sample -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") @mock.patch("flash._IS_TESTING", True) def test_serve(): model = TranslationTask(TEST_BACKBONE) diff --git a/tests/video/classification/test_data.py b/tests/video/classification/test_data.py index e07cd21d5e..c6cd9d2b0d 100644 --- a/tests/video/classification/test_data.py +++ b/tests/video/classification/test_data.py @@ -16,10 +16,10 @@ import pytest import torch -from flash.core.utilities.imports import _VIDEO_AVAILABLE +from flash.core.utilities.imports import _TOPIC_VIDEO_AVAILABLE from flash.video.classification.data import VideoClassificationData -if _VIDEO_AVAILABLE: +if _TOPIC_VIDEO_AVAILABLE: from pytorchvideo.data.utils import thwc_to_cthw @@ -35,7 +35,7 @@ def create_dummy_video_frames(num_frames: int, height: int, width: int): def temp_encoded_tensors(num_frames: int, height=10, width=10): - if not _VIDEO_AVAILABLE: + if not _TOPIC_VIDEO_AVAILABLE: return torch.randint(size=(3, num_frames, height, width), low=0, high=255) data = create_dummy_video_frames(num_frames, height, width) return thwc_to_cthw(data).to(torch.float32) @@ -61,7 +61,7 @@ def _check_frames(data, expected_frames_count: Union[list, int]): ), f"Expected video sample {idx} to have {expected_frames_count[idx]} frames but got {sample.shape[1]} frames" -@pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") @pytest.mark.parametrize( "input_data, input_targets, expected_frames_count", [ @@ -79,7 +79,7 @@ def test_load_data_from_tensors(input_data, input_targets, expected_frames_count _check_frames(data=datamodule.train_dataset.data, expected_frames_count=expected_frames_count) -@pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") @pytest.mark.parametrize( "input_data, input_targets, error_type, match", [ diff --git a/tests/video/classification/test_model.py b/tests/video/classification/test_model.py index f182c01eaa..9523e0bb19 100644 --- a/tests/video/classification/test_model.py +++ b/tests/video/classification/test_model.py @@ -25,7 +25,7 @@ import flash from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _VIDEO_AVAILABLE, _VIDEO_TESTING +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TOPIC_VIDEO_AVAILABLE from flash.video import VideoClassificationData, VideoClassifier from tests.helpers.task_tester import TaskTester from tests.video.classification.test_data import create_dummy_video_frames, temp_encoded_tensors @@ -33,7 +33,7 @@ if _FIFTYONE_AVAILABLE: import fiftyone as fo -if _VIDEO_AVAILABLE: +if _TOPIC_VIDEO_AVAILABLE: import torchvision.io as io from pytorchvideo.data.utils import thwc_to_cthw @@ -43,8 +43,8 @@ class TestVideoClassifier(TaskTester): task_args = (2,) task_kwargs = {"pretrained": False, "backbone": "slow_r50"} cli_command = "video_classification" - is_testing = _VIDEO_TESTING - is_available = _VIDEO_AVAILABLE + is_testing = _TOPIC_VIDEO_AVAILABLE + is_available = _TOPIC_VIDEO_AVAILABLE scriptable = False @@ -142,7 +142,7 @@ def mock_encoded_video_dataset_folder(tmpdir): yield str(tmp_dir), video_duration -@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") def test_video_classifier_finetune_from_folder(tmpdir): with mock_encoded_video_dataset_folder(tmpdir) as (mock_folder, total_duration): half_duration = total_duration / 2 - 1e-9 @@ -165,7 +165,7 @@ def test_video_classifier_finetune_from_folder(tmpdir): trainer.finetune(model, datamodule=datamodule) -@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") def test_video_classifier_finetune_from_files(tmpdir): with mock_video_data_frame() as (mock_data_frame, total_duration): half_duration = total_duration / 2 - 1e-9 @@ -189,7 +189,7 @@ def test_video_classifier_finetune_from_files(tmpdir): trainer.finetune(model, datamodule=datamodule) -@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") def test_video_classifier_finetune_from_data_frame(tmpdir): with mock_video_data_frame() as (mock_data_frame, total_duration): half_duration = total_duration / 2 - 1e-9 @@ -214,7 +214,7 @@ def test_video_classifier_finetune_from_data_frame(tmpdir): trainer.finetune(model, datamodule=datamodule) -@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") def test_video_classifier_finetune_from_tensors(tmpdir): mock_tensors = temp_encoded_tensors(num_frames=5) datamodule = VideoClassificationData.from_tensors( @@ -237,7 +237,7 @@ def test_video_classifier_finetune_from_tensors(tmpdir): trainer.finetune(model, datamodule=datamodule) -@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") def test_video_classifier_predict_from_tensors(tmpdir): mock_tensors = temp_encoded_tensors(num_frames=5) datamodule = VideoClassificationData.from_tensors( @@ -265,7 +265,7 @@ def test_video_classifier_predict_from_tensors(tmpdir): assert predictions[0][0] in datamodule.labels -@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") def test_video_classifier_finetune_from_csv(tmpdir): with mock_video_csv_file(tmpdir) as (mock_csv, total_duration): half_duration = total_duration / 2 - 1e-9 @@ -290,7 +290,7 @@ def test_video_classifier_finetune_from_csv(tmpdir): trainer.finetune(model, datamodule=datamodule) -@pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") +@pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.") def test_video_classifier_finetune_fiftyone(tmpdir): with mock_encoded_video_dataset_folder(tmpdir) as ( From 50139fc33df1c7a5ea2d897d14287fd6f50f8098 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 4 Apr 2023 07:31:51 +0200 Subject: [PATCH 18/39] [pre-commit.ci] pre-commit suggestions (#1540) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/PyCQA/docformatter: v1.5.0 → v1.5.1](https://github.com/PyCQA/docformatter/compare/v1.5.0...v1.5.1) - [github.com/psf/black: 23.1.0 → 23.3.0](https://github.com/psf/black/compare/23.1.0...23.3.0) - [github.com/charliermarsh/ruff-pre-commit: v0.0.240 → v0.0.260](https://github.com/charliermarsh/ruff-pre-commit/compare/v0.0.240...v0.0.260) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7573e42d35..607efa3931 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,7 +48,7 @@ repos: - id: nbstripout - repo: https://github.com/PyCQA/docformatter - rev: v1.5.0 + rev: v1.5.1 hooks: - id: docformatter args: @@ -57,7 +57,7 @@ repos: - "--wrap-descriptions=120" - repo: https://github.com/psf/black - rev: 23.1.0 + rev: 23.3.0 hooks: - id: black name: Format code @@ -77,7 +77,7 @@ repos: - "--skip-errors" - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.240 + rev: v0.0.260 hooks: - id: ruff args: ["--fix"] From 0db86c13e3a0fccd6a7f5ab4303fb547f60059ce Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 5 May 2023 18:05:35 +0200 Subject: [PATCH 19/39] precommit: update (#1546) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 6 +- .../image_classification_imagenette_mini.py | 1 - src/flash/core/serve/component.py | 16 ++-- src/flash/core/serve/core.py | 2 +- src/flash/core/serve/dag/task.py | 4 +- src/flash/core/serve/utils.py | 2 +- src/flash/core/utilities/lightning_cli.py | 4 +- .../integrations/learn2learn.py | 1 - tests/core/serve/test_dag/test_order.py | 74 ++++++++++--------- 9 files changed, 57 insertions(+), 53 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 607efa3931..7e63f80f90 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: - id: detect-private-key - repo: https://github.com/asottile/pyupgrade - rev: v3.3.1 + rev: v3.3.2 hooks: - id: pyupgrade args: [--py37-plus] @@ -48,7 +48,7 @@ repos: - id: nbstripout - repo: https://github.com/PyCQA/docformatter - rev: v1.5.1 + rev: v1.6.5 hooks: - id: docformatter args: @@ -77,7 +77,7 @@ repos: - "--skip-errors" - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.260 + rev: v0.0.264 hooks: - id: ruff args: ["--fix"] diff --git a/examples/integrations/learn2learn/image_classification_imagenette_mini.py b/examples/integrations/learn2learn/image_classification_imagenette_mini.py index ba84184174..b4d8603ef1 100644 --- a/examples/integrations/learn2learn/image_classification_imagenette_mini.py +++ b/examples/integrations/learn2learn/image_classification_imagenette_mini.py @@ -13,7 +13,6 @@ # limitations under the License. # adapted from https://github.com/learnables/learn2learn/blob/master/examples/vision/protonet_miniimagenet.py#L154 - """## Train file https://www.dropbox.com/s/9g8c6w345s2ek03/mini-imagenet-cache-train.pkl?dl=1 ## Validation File diff --git a/src/flash/core/serve/component.py b/src/flash/core/serve/component.py index 2c7dad8e77..d7c6a9ee92 100644 --- a/src/flash/core/serve/component.py +++ b/src/flash/core/serve/component.py @@ -196,15 +196,13 @@ def __call__(cls, *args, **kwargs): class ModelComponent(metaclass=FlashServeMeta): """Represents a computation which is decorated by `@expose`. - A component is how we represent the main unit of work; it is a set of - evaluations which involve some input being passed through some set of - functions to generate some set of outputs. - - To specify a component, we record things like: its name, source file - assets, configuration args, model source assets, etc. The - specification must be YAML serializable and loadable to/from a fully - initialized instance. It must contain the minimal set of information - necessary to find and initialize its dependencies (assets) and itself. + A component is how we represent the main unit of work; it is a set of evaluations which involve some input being + passed through some set of functions to generate some set of outputs. + + To specify a component, we record things like: its name, source file assets, configuration args, model source + assets, etc. The specification must be YAML serializable and loadable to/from a fully initialized instance. It + must contain the minimal set of information necessary to find and initialize its dependencies (assets) and + itself. """ _flashserve_meta_: Optional[Union[BoundMeta, UnboundMeta]] = None diff --git a/src/flash/core/serve/core.py b/src/flash/core/serve/core.py index 06447f9aa9..6ebe5bb415 100644 --- a/src/flash/core/serve/core.py +++ b/src/flash/core/serve/core.py @@ -206,7 +206,7 @@ def __str__(self): return f"{self.component_uid}.{self.position}.{self.name}" def __terminate_invalid_connection_request(self, other: "Parameter", dunder_meth_called: str) -> None: - """verify that components can be composed. + """Verify that components can be composed. Parameters ---------- diff --git a/src/flash/core/serve/dag/task.py b/src/flash/core/serve/dag/task.py index d1bb72c4f7..d1903d8819 100644 --- a/src/flash/core/serve/dag/task.py +++ b/src/flash/core/serve/dag/task.py @@ -28,8 +28,8 @@ def ishashable(x): def istask(x): - """Is x a runnable task? - A task is a tuple with a callable first argument + """Is x a runnable task? A task is a tuple with a callable first argument. + Examples -------- >>> istask((inc, 1)) diff --git a/src/flash/core/serve/utils.py b/src/flash/core/serve/utils.py index afb709b912..472493e47c 100644 --- a/src/flash/core/serve/utils.py +++ b/src/flash/core/serve/utils.py @@ -6,7 +6,7 @@ def fn_outputs_to_keyed_map(serialize_fn_out_keys, fn_output) -> Dict[str, Any]: - """convert outputs of a function to a dict of `{result_name: values}` + """Convert outputs of a function to a dict of `{result_name: values}` accepts function outputs which are sequence, dict, or object. """ diff --git a/src/flash/core/utilities/lightning_cli.py b/src/flash/core/utilities/lightning_cli.py index 1d89796a1c..864e047642 100644 --- a/src/flash/core/utilities/lightning_cli.py +++ b/src/flash/core/utilities/lightning_cli.py @@ -63,8 +63,8 @@ class LightningArgumentParser(ArgumentParser): def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None: """Initialize argument parser that supports configuration file input. - For full details of accepted arguments see `ArgumentParser.__init__ - `_. + For full details of accepted arguments see + `ArgumentParser.__init__ `_. """ super().__init__(*args, parse_as_dict=parse_as_dict, **kwargs) self.add_argument( diff --git a/src/flash/image/classification/integrations/learn2learn.py b/src/flash/image/classification/integrations/learn2learn.py index bdeb9f3096..8912a66a37 100644 --- a/src/flash/image/classification/integrations/learn2learn.py +++ b/src/flash/image/classification/integrations/learn2learn.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ Note: This file will be deleted once https://github.com/learnables/learn2learn/pull/257/files is merged within Learn2Learn. diff --git a/tests/core/serve/test_dag/test_order.py b/tests/core/serve/test_dag/test_order.py index 49f9bd5408..3ede55d64b 100644 --- a/tests/core/serve/test_dag/test_order.py +++ b/tests/core/serve/test_dag/test_order.py @@ -41,10 +41,11 @@ def test_ordering_keeps_groups_together(abcde): def test_avoid_broker_nodes(abcde): r"""Testing structure. - b0 b1 b2 + Example:: - | \ / - a0 a1 + b0 b1 b2 + | \ / + a0 a1 a0 should be run before a1 """ @@ -288,11 +289,13 @@ def test_prefer_short_dependents(abcde): def test_run_smaller_sections(abcde): r"""Testing structure. - aa - / | - b d bb dd - / \ /| | / - a c e cc + Example:: + + aa + / | + b d bb dd + / \ /| | / + a c e cc Prefer to run acb first because then we can get that out of the way """ @@ -379,9 +382,11 @@ def _(*args): def test_nearest_neighbor(abcde): r"""Testing structure. - a1 a2 a3 a4 a5 a6 a7 a8 a9 - \ | / \ | / \ | / \ | / - b1 b2 b3 b4 + Example:: + + a1 a2 a3 a4 a5 a6 a7 a8 a9 + \ | / \ | / \ | / \ | / + b1 b2 b3 b4 Want to finish off a local group before moving on. This is difficult because all groups are connected. @@ -512,14 +517,15 @@ def test_prefer_short_ancestor(abcde): def test_map_overlap(abcde): r"""Testing structure. - b1 b3 b5. + Example:: - |\ / | \ / | - c1 c2 c3 c4 c5 - |/ | \ | / | \| - d1 d2 d3 d4 d5 - | | | - e1 e2 e5 + b1 b3 b5. + |\ / | \ / | + c1 c2 c3 c4 c5 + |/ | \ | / | \| + d1 d2 d3 d4 d5 + | | | + e1 e2 e5 Want to finish b1 before we start on e5 """ @@ -667,21 +673,23 @@ def test_order_empty(): def test_switching_dependents(abcde): r"""Testing structure. - a7 a8 <-- do these last - | / - a6 e6 - | / - a5 c5 d5 e5 - | | / / - a4 c4 d4 e4 - | \ | / / - a3 b3---/ - | - a2 - | - a1 - | - a0 <-- start here + Example:: + + a7 a8 <-- do these last + | / + a6 e6 + | / + a5 c5 d5 e5 + | | / / + a4 c4 d4 e4 + | \ | / / + a3 b3---/ + | + a2 + | + a1 + | + a0 <-- start here Test that we are able to switch to better dependents. In this graph, we expect to start at a0. To compute a4, we need to compute b3. From 13545c083c519062ea3d613d3159b49d101c3649 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 5 May 2023 19:59:52 +0200 Subject: [PATCH 20/39] req: bump `pytorch-tabular>=1.0` (#1545) --- .github/workflows/ci-testing.yml | 2 +- requirements.txt | 2 +- requirements/datatype_tabular.txt | 5 +++-- src/flash/core/data/base_viz.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 045bc72f37..3fbed1a27e 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -105,7 +105,7 @@ jobs: env: SYSTEM_VERSION_COMPAT: 1 run: | - pip --version + python -m pip install "pip==22.2.1" pip install cython "torch>=1.7.1" -f https://download.pytorch.org/whl/cpu/torch_stable.html pip install .[$EXTRAS,test] --upgrade --prefer-binary --find-links https://download.pytorch.org/whl/cpu/torch_stable.html diff --git a/requirements.txt b/requirements.txt index 339b2bbead..c2dbe25984 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,4 @@ jsonargparse[signatures] >=3.17.0, <=4.9.0 click >=7.1.2, <=8.1.3 protobuf <=3.20.1 fsspec[http] >=2022.5.0,<=2022.7.1 -lightning-utilities >=0.4.1 +lightning-utilities >=0.3.0 diff --git a/requirements/datatype_tabular.txt b/requirements/datatype_tabular.txt index e81f4b143a..d00ab4b193 100644 --- a/requirements/datatype_tabular.txt +++ b/requirements/datatype_tabular.txt @@ -2,6 +2,7 @@ scikit-learn <=1.2.0 pytorch-forecasting>=0.9.0, <=0.10.3 -pytorch-tabular==0.7.0 -torchmetrics<0.8.0 # pytorch-tabular pins PL so we force a compatible TM version +# pytorch-tabular>=1.0.2, <1.0.3 # pending requirements resolving +https://github.com/Lightning-Sandbox/pytorch_tabular/archive/refs/heads/req/replace.zip +torchmetrics>=0.10.0 omegaconf<=2.1.1, <=2.1.1 diff --git a/src/flash/core/data/base_viz.py b/src/flash/core/data/base_viz.py index 3d4f7d5922..9f8c90cae7 100644 --- a/src/flash/core/data/base_viz.py +++ b/src/flash/core/data/base_viz.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Dict, List, Set, Tuple -from lightning_utilities.core.overrides import is_overridden +from pytorch_lightning.utilities.model_helpers import is_overridden from flash.core.data.callback import BaseDataFetcher from flash.core.data.utils import _CALLBACK_FUNCS From c8e433de7d301d16d56aa9d6522babe58551104d Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Sat, 6 May 2023 16:01:03 +0200 Subject: [PATCH 21/39] CI + testing req. (#1488) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .azure/gpu-example-tests.yml | 2 - .azure/gpu-special-tests.yml | 18 +---- .azure/template-examples.yml | 31 +------- .github/workflows/ci-testing.yml | 72 +++++++------------ requirements.txt | 2 +- requirements/datatype_audio.txt | 8 +-- requirements/datatype_graph.txt | 6 +- requirements/datatype_image.txt | 16 +++-- ...xtras_baal.txt => datatype_image_baal.txt} | 5 +- requirements/datatype_image_extras.txt | 21 +++--- requirements/datatype_image_segm.txt | 4 ++ requirements/datatype_image_vissl.txt | 4 ++ requirements/datatype_pointcloud.txt | 7 +- requirements/datatype_tabular.txt | 12 ++-- requirements/datatype_text.txt | 10 +-- requirements/datatype_video.txt | 10 +-- requirements/datatype_video_extras.txt | 3 - requirements/docs.txt | 2 +- requirements/serve.txt | 14 ++-- requirements/test.txt | 10 +-- requirements/testing_audio.txt | 10 +++ requirements/testing_core.txt | 0 requirements/testing_graph.txt | 5 ++ requirements/testing_image.txt | 2 + requirements/testing_pointcloud.txt | 0 requirements/testing_serve.txt | 7 ++ requirements/testing_tabular.txt | 0 requirements/testing_text.txt | 0 requirements/testing_video.txt | 0 setup.py | 17 ++--- src/flash/core/utilities/flash_cli.py | 2 +- src/flash/core/utilities/imports.py | 1 - tests/core/test_model.py | 52 +++++--------- tests/examples/test_scripts.py | 1 + tests/image/semantic_segm/test_backbones.py | 9 +-- tests/image/semantic_segm/test_data.py | 11 +-- tests/image/semantic_segm/test_heads.py | 16 ++--- tests/image/semantic_segm/test_model.py | 18 ++--- tests/image/semantic_segm/test_output.py | 14 ++-- 39 files changed, 177 insertions(+), 245 deletions(-) rename requirements/{datatype_image_extras_baal.txt => datatype_image_baal.txt} (66%) create mode 100644 requirements/datatype_image_segm.txt create mode 100644 requirements/datatype_image_vissl.txt delete mode 100644 requirements/datatype_video_extras.txt create mode 100644 requirements/testing_audio.txt create mode 100644 requirements/testing_core.txt create mode 100644 requirements/testing_graph.txt create mode 100644 requirements/testing_image.txt create mode 100644 requirements/testing_pointcloud.txt create mode 100644 requirements/testing_serve.txt create mode 100644 requirements/testing_tabular.txt create mode 100644 requirements/testing_text.txt create mode 100644 requirements/testing_video.txt diff --git a/.azure/gpu-example-tests.yml b/.azure/gpu-example-tests.yml index 3820f48254..51f1826dca 100644 --- a/.azure/gpu-example-tests.yml +++ b/.azure/gpu-example-tests.yml @@ -12,8 +12,6 @@ jobs: parameters: domains: - "image" - - "icevision" - - "vissl" - "text" - "tabular" - "video" diff --git a/.azure/gpu-special-tests.yml b/.azure/gpu-special-tests.yml index f0cc5a3392..5b3060ac6a 100644 --- a/.azure/gpu-special-tests.yml +++ b/.azure/gpu-special-tests.yml @@ -53,7 +53,7 @@ jobs: - bash: | # python -m pip install "pip==20.1" - pip install '.[image,test]' learn2learn + pip install '.[image,test]' -r requirements/testing_image.txt pip list env: FREEZE_REQUIREMENTS: 1 @@ -70,19 +70,3 @@ jobs: python -m codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure ls -l displayName: 'Statistics' - - - task: PublishTestResults@2 - displayName: 'Publish test results' - inputs: - testResultsFiles: '$(Build.StagingDirectory)/test-results.xml' - testRunTitle: '$(Agent.OS) - $(Build.DefinitionName) - Python $(python.version)' - condition: succeededOrFailed() - - - task: PublishCodeCoverageResults@1 - displayName: 'Publish coverage report' - inputs: - codeCoverageTool: 'cobertura' - summaryFileLocation: 'coverage.xml' - reportDirectory: '$(Build.SourcesDirectory)/htmlcov' - testRunTitle: '$(Agent.OS) - $(Build.BuildNumber)[$(Agent.JobName)] - Python $(python.version)' - condition: succeededOrFailed() diff --git a/.azure/template-examples.yml b/.azure/template-examples.yml index 5e6fb15a5a..ab9445c170 100644 --- a/.azure/template-examples.yml +++ b/.azure/template-examples.yml @@ -14,7 +14,7 @@ jobs: # this need to have installed docker in the base image... container: # base ML image: mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04 - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.10" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.11" # image: "pytorch/pytorch:1.8.1-cuda11.0-cudnn8-runtime" options: "-it --rm --gpus=all --shm-size=16g" @@ -40,39 +40,14 @@ jobs: displayName: 'Sanity check' - bash: | - # python -m pip install "pip==20.1" - if [ "${{dom}}" == "icevision" ]; then - pip install '.[image]' icevision effdet icedata; - elif [ "${{dom}}" == "vissl" ]; then - pip install '.[image]'; - else - pip install '.[${{dom}}]'; - fi - pip install '.[test]' --upgrade-strategy only-if-needed + pip install '.[${{dom}},test]' -r requirements/testing_${{dom}}.txt pip list env: FREEZE_REQUIREMENTS: 1 displayName: 'Install dependencies' - bash: | - pip uninstall -y opencv-python opencv-python-headless - pip install opencv-python-headless==4.5.5.64 - displayName: 'Install OpenCV dependencies' - condition: eq('${{ dom }}', 'icevision') - - - bash: | - pip install fairscale - pip install git+https://github.com/facebookresearch/ClassyVision.git - pip install git+https://github.com/facebookresearch/vissl.git - displayName: 'Install VISSL dependencies' - condition: eq('${{ dom }}', 'vissl') - - - bash: | - python -c "import torch; print(f'found GPUs: {torch.cuda.device_count()}')" - python -m coverage run --source flash -m pytest \ - tests/examples/test_scripts.py \ - tests/image/embedding/test_model.py \ - -v --junitxml=$(Build.StagingDirectory)/test-results.xml --durations=30 + python -m coverage run --source flash -m pytest --durations=30 env: FLASH_TEST_TOPIC: ${{ dom }} displayName: 'Testing' diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 3fbed1a27e..da2b686559 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -32,14 +32,15 @@ jobs: exclude: # Skip if torch<1.8 and py3.9 on Linux: https://github.com/pytorch/pytorch/issues/50014 - { python-version: 3.9, requires: 'oldest' } - - { os: macOS-12, requires: 'oldest' } - - { os: windows-2022, requires: 'oldest' } + - { os: 'macOS-12', requires: 'oldest' } + - { os: 'windows-2022', requires: 'oldest' } include: - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'core', extra: []} - - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_extras']} - - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_extras_baal']} + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_extra']} + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_baal']} + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_segm']} + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_vissl']} - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'video', extra: []} - - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'video', extra: ['video_extras']} - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'tabular', extra: []} - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'text', extra: []} - { os: 'ubuntu-20.04', python-version: 3.8, topic: 'pointcloud', extra: []} @@ -68,31 +69,24 @@ jobs: # Github Actions: Run step on specific OS: https://stackoverflow.com/a/57948488/4521646 - name: Setup macOS if: runner.os == 'macOS' - run: | - brew update - brew install libomp # https://github.com/pytorch/pytorch/issues/20030 + run: brew install libomp openblas lapack - - name: Install graphviz - if: contains( matrix.topic , 'serve' ) - run: sudo apt-get install graphviz + - name: Setup Ubuntu + if: runner.os == 'Linux' + run: sudo apt-get install -y libsndfile1 graphviz - name: Set min. dependencies if: matrix.requires == 'oldest' run: | - fname = 'requirements.txt' - ignore = ['pandas', 'torchmetrics'] - lines = [line if any([line.startswith(package) for package in ignore]) else line.replace('>', '=') for line in open(fname).readlines()] - open(fname, 'w').writelines(lines) + import glob, os + # FixMe: shall be minimal for ALL dependencies not only base + # files = glob.glob(os.path.join("requirements", "*.txt")) + ['requirements.txt'] + files = ['requirements.txt'] + for fname in files: + lines = [line.replace('>=', '==') for line in open(fname).readlines()] + open(fname, 'w').writelines(lines) shell: python - - name: Install graph test dependencies - if: contains( matrix.topic , 'graph' ) - run: | - pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html - pip install torch-scatter -f https://data.pyg.org/whl/torch-1.11.0+cpu.html - pip install torch-sparse -f https://data.pyg.org/whl/torch-1.11.0+cpu.html - pip install torch-cluster -f https://data.pyg.org/whl/torch-1.11.0+cpu.html - - name: Adjust extras run: | import os @@ -101,32 +95,16 @@ jobs: gh_env.write(f"EXTRAS={','.join(extras)}") shell: python - - name: Install package & dependencies + - name: Install dependencies env: - SYSTEM_VERSION_COMPAT: 1 + TORCH_URL: https://download.pytorch.org/whl/cpu/torch_stable.html run: | python -m pip install "pip==22.2.1" - pip install cython "torch>=1.7.1" -f https://download.pytorch.org/whl/cpu/torch_stable.html - pip install .[$EXTRAS,test] --upgrade --prefer-binary --find-links https://download.pytorch.org/whl/cpu/torch_stable.html - - - name: Install vissl - if: contains( matrix.topic , 'image_extras' ) - run: | - pip install git+https://github.com/facebookresearch/ClassyVision.git - pip install git+https://github.com/facebookresearch/vissl.git - - - name: Install serve test dependencies - if: contains( matrix.topic , 'serve' ) - run: | - sudo apt-get install libsndfile1 - pip install '.[all,audio]' icevision sahi==0.8.19 effdet --upgrade - - - name: Install audio test dependencies - if: contains( matrix.topic , 'audio' ) - run: | - sudo apt-get install libsndfile1 - pip install matplotlib - pip install '.[audio,image]' torch==1.11.0 --upgrade + pip install numpy Cython "torch>=1.7.1" -f $TORCH_URL + pip install .[$EXTRAS,test] \ + -r requirements/testing_${{ matrix.topic }}.txt \ + --upgrade --prefer-binary -f $TORCH_URL + pip list - name: Cache datasets uses: actions/cache@v3 @@ -137,10 +115,8 @@ jobs: - name: Tests env: - FLASH_TEST_TOPIC: ${{ join(matrix.topic,',') }} FIFTYONE_DO_NOT_TRACK: true run: | - pip list # FixMe: include doctests for src/ coverage run --source flash -m pytest \ tests/core \ diff --git a/requirements.txt b/requirements.txt index c2dbe25984..be60c76589 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup -packaging +packaging <23.0 setuptools <=59.5.0 # Prevent install bug with tensorboard numpy <1.24 # strict - freeze for using np.long torch >=1.7.1 diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt index 5844520660..c28ac77842 100644 --- a/requirements/datatype_audio.txt +++ b/requirements/datatype_audio.txt @@ -1,7 +1,7 @@ -# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup torchaudio <=0.13.1 torchvision <=0.14.1 -librosa>=0.8.1, <=0.9.2 -transformers>=4.13.0, <=4.25.1 -datasets>=1.16.1, <=2.8.0 +librosa >=0.8.1, <=0.9.2 +transformers >=4.13.0, <=4.25.1 +datasets >=1.16.1, <=2.8.0 diff --git a/requirements/datatype_graph.txt b/requirements/datatype_graph.txt index ff65a1399c..fe223f68c8 100644 --- a/requirements/datatype_graph.txt +++ b/requirements/datatype_graph.txt @@ -1,8 +1,8 @@ -# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup torch-scatter <=2.1.0 torch-sparse <=0.6.16 -torch-geometric>=2.0.0, <=2.2.0 +torch-geometric >=2.0.0, <=2.2.0 torch-cluster <=1.6.0 networkx <=2.8.8 -class-resolver>=0.3.2, <=0.3.10 +class-resolver >=0.3.2, <=0.3.10 diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt index ea7550693a..faf62371df 100644 --- a/requirements/datatype_image.txt +++ b/requirements/datatype_image.txt @@ -1,12 +1,14 @@ -# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup torchvision <=0.14.1 -timm>=0.4.5, <=0.4.12 -lightning-bolts>=0.3.3, <=0.6.0 -Pillow>=7.2, <=9.3.0 +timm >0.4.5, <=0.6.11 # effdet 0.3.0 depends on timm>=0.4.12 +lightning-bolts >0.3.3, <=0.6.0 +Pillow >7.1, <=9.3.0 albumentations <=1.3.0 -pystiche>=1.0.0, <=1.0.1 -segmentation-models-pytorch>=0.2.0, <=0.3.1 +pystiche >1.0.0, <=1.0.1 ftfy <=6.1.1 regex <=2022.10.31 -sahi <0.11 # strict - Fixes compatibility with icevision +sahi >=0.8.19, <0.11 # strict - Fixes compatibility with icevision + +icevision >0.8 +icedata <=0.5.1 # dead diff --git a/requirements/datatype_image_extras_baal.txt b/requirements/datatype_image_baal.txt similarity index 66% rename from requirements/datatype_image_extras_baal.txt rename to requirements/datatype_image_baal.txt index 655862dfd9..05f17d6913 100644 --- a/requirements/datatype_image_extras_baal.txt +++ b/requirements/datatype_image_baal.txt @@ -1,5 +1,4 @@ -# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup # This is a separate file, as baal integration is affected by vissl installation (conflicts) -baal>=1.3.2, <=1.7.0 -icevision >=0.8 +baal >=1.3.2, <=1.7.0 diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt index ff6e9c9c7e..11a8f19a8e 100644 --- a/requirements/datatype_image_extras.txt +++ b/requirements/datatype_image_extras.txt @@ -1,20 +1,19 @@ # NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup matplotlib <=3.6.2 -fiftyone -classy_vision -vissl>=0.1.5 -icevision >=0.8 -sahi >=0.8.19,<0.11.0 -icedata -effdet -kornia>=0.5.1 -learn2learn; platform_system != "Windows" # dead -fastface +fiftyone <0.19.0 +classy-vision <=0.6 +effdet <=0.3.0 +kornia >0.5.1, <=0.6.9 +learn2learn <=0.1.7; platform_system != "Windows" # dead +fastface <=0.1.3 # dead fairscale # pin PL for testing, remove when fastface is updated pytorch-lightning <1.5.0 -torchmetrics<0.8.0 # pinned PL so we force a compatible TM version + +# pinned PL so we force a compatible TM version +torchmetrics<0.8.0 + # effdet had an issue with PL 1.12, and icevision doesn't support effdet's latest version yet (0.3.0) torch <1.12 diff --git a/requirements/datatype_image_segm.txt b/requirements/datatype_image_segm.txt new file mode 100644 index 0000000000..cf37ef2c0d --- /dev/null +++ b/requirements/datatype_image_segm.txt @@ -0,0 +1,4 @@ +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup + +# This is a separate file, as segmentation integration is affected by vissl installation (conflicts) +segmentation-models-pytorch >0.2.0, <=0.3.1 diff --git a/requirements/datatype_image_vissl.txt b/requirements/datatype_image_vissl.txt new file mode 100644 index 0000000000..1d196351d9 --- /dev/null +++ b/requirements/datatype_image_vissl.txt @@ -0,0 +1,4 @@ +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup + +# This is a separate file, as vissl integration is affected by baal installation (conflicts) +vissl >=0.1.5, <=0.1.6 # dead diff --git a/requirements/datatype_pointcloud.txt b/requirements/datatype_pointcloud.txt index aa95459400..52666d7698 100644 --- a/requirements/datatype_pointcloud.txt +++ b/requirements/datatype_pointcloud.txt @@ -1,6 +1,7 @@ -# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup open3d ==0.13 -torch ==1.7.1 -torchvision ==0.8.2 +# Version mismatch: Open3D needs PyTorch version 1.7.*, but version 1.8.1+cpu is installed! +torch >=1.7.0, <1.8.0 +torchvision >=0.8.0, <0.9.0 tensorboard <=2.11.0 diff --git a/requirements/datatype_tabular.txt b/requirements/datatype_tabular.txt index d00ab4b193..30ca9c29d5 100644 --- a/requirements/datatype_tabular.txt +++ b/requirements/datatype_tabular.txt @@ -1,8 +1,8 @@ -# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup scikit-learn <=1.2.0 -pytorch-forecasting>=0.9.0, <=0.10.3 -# pytorch-tabular>=1.0.2, <1.0.3 # pending requirements resolving -https://github.com/Lightning-Sandbox/pytorch_tabular/archive/refs/heads/req/replace.zip -torchmetrics>=0.10.0 -omegaconf<=2.1.1, <=2.1.1 +pytorch-forecasting >=0.9.0, <=0.10.3 +# pytorch-tabular >=1.0.2, <1.0.3 # pending requirements resolving +https://github.com/manujosephv/pytorch_tabular/archive/refs/heads/main.zip +torchmetrics >=0.10.0 +omegaconf <=2.1.1, <=2.1.1 diff --git a/requirements/datatype_text.txt b/requirements/datatype_text.txt index 22da582b8c..786aa1f9d9 100644 --- a/requirements/datatype_text.txt +++ b/requirements/datatype_text.txt @@ -1,11 +1,11 @@ -# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup torchvision <=0.14.1 -sentencepiece>=0.1.95, <=0.1.97 +sentencepiece >=0.1.95, <=0.1.97 filelock <=3.8.2 -transformers>=4.5, <=4.25.1 -torchmetrics[text]>=0.5.1, <0.11.0 -datasets>=1.8, <=2.8.0 +transformers >4.13.0, <=4.25.1 +torchmetrics[text] >0.5.0, <0.11.0 +datasets >2.0.0, <=2.8.0 sentence-transformers <=2.2.2 ftfy <=6.1.1 regex <=2022.10.31 diff --git a/requirements/datatype_video.txt b/requirements/datatype_video.txt index 9c125d4d34..49c6ca8fdd 100644 --- a/requirements/datatype_video.txt +++ b/requirements/datatype_video.txt @@ -1,6 +1,8 @@ -# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup torchvision <=0.14.1 -Pillow>=7.2, <=9.3.0 -kornia>=0.5.1, <=0.6.9 -pytorchvideo==0.1.2 +Pillow >7.1, <=9.3.0 +kornia >=0.5.1, <=0.6.9 +pytorchvideo ==0.1.2 + +fiftyone <=0.18.0 diff --git a/requirements/datatype_video_extras.txt b/requirements/datatype_video_extras.txt deleted file mode 100644 index 1c10853c29..0000000000 --- a/requirements/datatype_video_extras.txt +++ /dev/null @@ -1,3 +0,0 @@ -# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup - -fiftyone <=0.18.0 diff --git a/requirements/docs.txt b/requirements/docs.txt index 7cec7ee7a7..5fb32043ee 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,4 +1,4 @@ -# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup sphinx >=4.0, <5.0 myst-parser >=0.15 diff --git a/requirements/serve.txt b/requirements/serve.txt index 792fb2e20e..e581827880 100644 --- a/requirements/serve.txt +++ b/requirements/serve.txt @@ -1,14 +1,14 @@ -# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup -pillow <=9.3.0 +pillow >7.1, <=9.3.0 pyyaml <=6.0 cytoolz <=0.12.1 graphviz <=0.20.1 tqdm <=4.64.1 -fastapi>=0.65.2, <=0.68.2 -pydantic>1.8.1, <=1.10.2 -starlette==0.14.2 -uvicorn[standard]>=0.12.0, <=0.20.0 +fastapi >=0.65.2, <=0.68.2 +pydantic >1.8.1, <=1.10.2 +starlette ==0.14.2 +uvicorn[standard] >=0.12.0, <=0.20.0 aiofiles <=22.1.0 -jinja2>=3.0.0, <3.1.0 +jinja2 >=3.0.0, <3.1.0 torchvision <=0.14.1 diff --git a/requirements/test.txt b/requirements/test.txt index 86953722af..bc1e3fb2a2 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,10 +1,10 @@ -# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forecd in setup +# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup coverage[toml] -codecov>=2.1 -pytest>=6.0, <7.0 -pytest-doctestplus>=0.9.0 -pytest-rerunfailures>=10.0 +codecov >2.1 +pytest >7.2, <7.4 +pytest-doctestplus >0.9.0 +pytest-rerunfailures >10.0 pytest-forked pytest-mock diff --git a/requirements/testing_audio.txt b/requirements/testing_audio.txt new file mode 100644 index 0000000000..04365b9b5f --- /dev/null +++ b/requirements/testing_audio.txt @@ -0,0 +1,10 @@ +matplotlib +torch ==1.11.0 +torchaudio ==0.11.0 +torchvision ==0.12.0 + +timm >0.4.5, <=0.6.11 # effdet 0.3.0 depends on timm>=0.4.12 +lightning-bolts >=0.3.3, <=0.6.0 +Pillow >7.1, <=9.3.0 +albumentations <=1.3.0 +pystiche >1.0.0, <=1.0.1 diff --git a/requirements/testing_core.txt b/requirements/testing_core.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/requirements/testing_graph.txt b/requirements/testing_graph.txt new file mode 100644 index 0000000000..59cd09b99d --- /dev/null +++ b/requirements/testing_graph.txt @@ -0,0 +1,5 @@ +torch ==1.11.0 +torchvision ==0.12.0 + +-f https://download.pytorch.org/whl/cpu/torch_stable.html +-f https://data.pyg.org/whl/torch-1.11.0+cpu.html diff --git a/requirements/testing_image.txt b/requirements/testing_image.txt new file mode 100644 index 0000000000..9e6f6186d5 --- /dev/null +++ b/requirements/testing_image.txt @@ -0,0 +1,2 @@ +# https://github.com/facebookresearch/ClassyVision/archive/refs/heads/main.zip +# https://github.com/facebookresearch/vissl/archive/refs/heads/main.zip diff --git a/requirements/testing_pointcloud.txt b/requirements/testing_pointcloud.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/requirements/testing_serve.txt b/requirements/testing_serve.txt new file mode 100644 index 0000000000..95d1c9779a --- /dev/null +++ b/requirements/testing_serve.txt @@ -0,0 +1,7 @@ +sahi ==0.8.19 + +-r datatype_image.txt +-r datatype_video.txt +-r datatype_tabular.txt +-r datatype_text.txt +-r datatype_audio.txt diff --git a/requirements/testing_tabular.txt b/requirements/testing_tabular.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/requirements/testing_text.txt b/requirements/testing_text.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/requirements/testing_video.txt b/requirements/testing_video.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/setup.py b/setup.py index c7ddb9e911..87dc5d3727 100644 --- a/setup.py +++ b/setup.py @@ -125,19 +125,16 @@ def _expand_reqs(extras: dict, keys: list) -> list: def _get_extras(path_dir: str = _PATH_REQUIRE) -> dict: _load_req = partial(_load_requirements, path_dir=path_dir) found_req_files = sorted(os.path.basename(p) for p in glob.glob(os.path.join(path_dir, "*.txt"))) + found_req_files = [p for p in found_req_files if not p.startswith("testing_")] # remove datatype prefix found_req_names = [os.path.splitext(req)[0].replace("datatype_", "") for req in found_req_files] # define basic and extra extras - extras_req = { - name: _load_req(file_name=fname) for name, fname in zip(found_req_names, found_req_files) if "_" not in name - } - extras_req.update( - { - name: extras_req[name.split("_")[0]] + _load_req(file_name=fname) - for name, fname in zip(found_req_names, found_req_files) - if "_" in name - } - ) + extras_req = {name: _load_req(file_name=fname) for name, fname in zip(found_req_names, found_req_files)} + # extras_req.update({ + # name: extras_req[name.split("_")[0]] + _load_req(file_name=fname) + # for name, fname in zip(found_req_names, found_req_files) + # if "_" in name + # }) # some extra combinations extras_req["vision"] = _expand_reqs(extras_req, ["image", "video"]) extras_req["core"] = _expand_reqs(extras_req, ["image", "tabular", "text"]) diff --git a/src/flash/core/utilities/flash_cli.py b/src/flash/core/utilities/flash_cli.py index 75f34dd98f..5ceb592a32 100644 --- a/src/flash/core/utilities/flash_cli.py +++ b/src/flash/core/utilities/flash_cli.py @@ -22,8 +22,8 @@ import pytorch_lightning as pl from jsonargparse import ArgumentParser from jsonargparse.signatures import get_class_signature_functions +from lightning_utilities.core.overrides import is_overridden from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.utilities.model_helpers import is_overridden import flash from flash.core.data.data_module import DataModule diff --git a/src/flash/core/utilities/imports.py b/src/flash/core/utilities/imports.py index 07adb6f7ad..24f787899f 100644 --- a/src/flash/core/utilities/imports.py +++ b/src/flash/core/utilities/imports.py @@ -112,7 +112,6 @@ class Image: _PIL_AVAILABLE, _ALBUMENTATIONS_AVAILABLE, _PYSTICHE_AVAILABLE, - _SEGMENTATION_MODELS_AVAILABLE, ] ) _TOPIC_SERVE_AVAILABLE = all([_FASTAPI_AVAILABLE, _PYDANTIC_AVAILABLE, _CYTOOLZ_AVAILABLE, _UVICORN_AVAILABLE]) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 489d100662..8cb1fd5349 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -40,6 +40,7 @@ from flash.core.utilities.embedder import Embedder from flash.core.utilities.imports import ( _PL_GREATER_EQUAL_1_8_0, + _SEGMENTATION_MODELS_AVAILABLE, _TM_GREATER_EQUAL_0_10_0, _TOPIC_AUDIO_AVAILABLE, _TOPIC_CORE_AVAILABLE, @@ -213,74 +214,53 @@ def test_classification_task_trainer_predict(tmpdir): pytest.param( ImageClassifier, "0.7.0/image_classification_model.pt", - marks=pytest.mark.skipif( - not _TOPIC_IMAGE_AVAILABLE, - reason="image packages aren't installed", - ), + marks=pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image packages aren't installed"), ), pytest.param( SemanticSegmentation, "0.9.0/semantic_segmentation_model.pt", - marks=pytest.mark.skipif( - not _TOPIC_IMAGE_AVAILABLE or not _TM_GREATER_EQUAL_0_10_0, - reason="image packages aren't installed", - ), + marks=[ + pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image packages aren't installed"), + pytest.mark.skipif( + not _SEGMENTATION_MODELS_AVAILABLE, reason="segmentation_models_pytorch package is not installed" + ), + pytest.mark.skipif(not _TM_GREATER_EQUAL_0_10_0, reason="TM compatibility"), + ], ), pytest.param( SpeechRecognition, "0.7.0/speech_recognition_model.pt", - marks=pytest.mark.skipif( - not _TOPIC_AUDIO_AVAILABLE, - reason="audio packages aren't installed", - ), + marks=pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio packages aren't installed"), ), pytest.param( TabularClassifier, "0.7.0/tabular_classification_model.pt", - marks=pytest.mark.skipif( - not _TOPIC_TABULAR_AVAILABLE, - reason="tabular packages aren't installed", - ), + marks=pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular packages aren't installed"), ), pytest.param( TextClassifier, "0.9.0/text_classification_model.pt", - marks=pytest.mark.skipif( - not _TOPIC_TEXT_AVAILABLE, - reason="text packages aren't installed", - ), + marks=pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text packages aren't installed"), ), pytest.param( SummarizationTask, "0.7.0/summarization_model_xsum.pt", - marks=pytest.mark.skipif( - not _TOPIC_TEXT_AVAILABLE, - reason="text packages aren't installed", - ), + marks=pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text packages aren't installed"), ), pytest.param( TranslationTask, "0.7.0/translation_model_en_ro.pt", - marks=pytest.mark.skipif( - not _TOPIC_TEXT_AVAILABLE, - reason="text packages aren't installed", - ), + marks=pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text packages aren't installed"), ), pytest.param( GraphClassifier, "0.7.0/graph_classification_model.pt", - marks=pytest.mark.skipif( - not _TOPIC_GRAPH_AVAILABLE, - reason="graph packages aren't installed", - ), + marks=pytest.mark.skipif(not _TOPIC_GRAPH_AVAILABLE, reason="graph packages aren't installed"), ), pytest.param( GraphEmbedder, "0.7.0/graph_classification_model.pt", - marks=pytest.mark.skipif( - not _TOPIC_GRAPH_AVAILABLE, - reason="graph packages aren't installed", - ), + marks=pytest.mark.skipif(not _TOPIC_GRAPH_AVAILABLE, reason="graph packages aren't installed"), ), ], ) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index ae87587603..2f8202b252 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -79,6 +79,7 @@ marks=[ pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), pytest.mark.skipif(not _ICEDATA_AVAILABLE, reason="icedata package isn't installed"), + pytest.mark.xfail(strict=False), # ToDo ], ), pytest.param( diff --git a/tests/image/semantic_segm/test_backbones.py b/tests/image/semantic_segm/test_backbones.py index 4b8fb7a7a7..ba05e83c7a 100644 --- a/tests/image/semantic_segm/test_backbones.py +++ b/tests/image/semantic_segm/test_backbones.py @@ -17,13 +17,8 @@ from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES -@pytest.mark.parametrize( - ["backbone"], - [ - pytest.param("resnet50", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), - pytest.param("dpn131", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), - ], -) +@pytest.mark.parametrize("backbone", ["resnet50", "dpn131"]) +@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP") def test_semantic_segmentation_backbones_registry(backbone): backbone = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone)() assert backbone diff --git a/tests/image/semantic_segm/test_data.py b/tests/image/semantic_segm/test_data.py index c865285055..0c07e99fb0 100644 --- a/tests/image/semantic_segm/test_data.py +++ b/tests/image/semantic_segm/test_data.py @@ -12,6 +12,7 @@ _FIFTYONE_AVAILABLE, _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, + _SEGMENTATION_MODELS_AVAILABLE, _TOPIC_IMAGE_AVAILABLE, ) from flash.image import SemanticSegmentation, SemanticSegmentationData @@ -50,15 +51,14 @@ def create_random_data(image_files: List[str], label_files: List[str], size: Tup _rand_labels(size, num_classes).save(label_file) +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") class TestSemanticSegmentationData: @staticmethod - @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_smoke(): dm = SemanticSegmentationData(batch_size=1) assert dm is not None @staticmethod - @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_folders(tmpdir): tmp_dir = Path(tmpdir) @@ -120,7 +120,6 @@ def test_from_folders(tmpdir): assert labels.shape == (2, 128, 128) @staticmethod - @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_folders_different_extensions(tmpdir): tmp_dir = Path(tmpdir) @@ -182,7 +181,6 @@ def test_from_folders_different_extensions(tmpdir): assert labels.shape == (2, 128, 128) @staticmethod - @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_folders_error(tmpdir): tmp_dir = Path(tmpdir) @@ -217,7 +215,6 @@ def test_from_folders_error(tmpdir): ) @staticmethod - @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_files(tmpdir): tmp_dir = Path(tmpdir) @@ -276,7 +273,6 @@ def test_from_files(tmpdir): assert labels.shape == (2, 128, 128) @staticmethod - @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") def test_from_files_warning(tmpdir): tmp_dir = Path(tmpdir) @@ -310,7 +306,6 @@ def test_from_files_warning(tmpdir): ) @staticmethod - @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") def test_from_fiftyone(tmpdir): tmp_dir = Path(tmpdir) @@ -380,8 +375,8 @@ def test_from_fiftyone(tmpdir): assert imgs.shape == (2, 3, 128, 128) @staticmethod - @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") + @pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP") def test_map_labels(tmpdir): tmp_dir = Path(tmpdir) diff --git a/tests/image/semantic_segm/test_heads.py b/tests/image/semantic_segm/test_heads.py index cb18b2bfd1..19472860c7 100644 --- a/tests/image/semantic_segm/test_heads.py +++ b/tests/image/semantic_segm/test_heads.py @@ -11,25 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import unittest.mock +import unittest import pytest import torch -from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE, _TOPIC_IMAGE_AVAILABLE +from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE from flash.image.segmentation import SemanticSegmentation from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS -@pytest.mark.parametrize( - "head", - [ - pytest.param("fpn", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), - pytest.param("deeplabv3", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), - pytest.param("unet", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")), - ], -) +@pytest.mark.parametrize("head", ["fpn", "deeplabv3", "unet"]) +@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP") def test_semantic_segmentation_heads_registry(head): img = torch.rand(1, 3, 32, 32) backbone = SEMANTIC_SEGMENTATION_BACKBONES.get("resnet50")(pretrained=False) @@ -43,7 +37,7 @@ def test_semantic_segmentation_heads_registry(head): assert res.shape[1] == 10 -@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP") @unittest.mock.patch("flash.image.segmentation.heads.smp") def test_pretrained_weights(mock_smp): mock_smp.create_model = unittest.mock.MagicMock() diff --git a/tests/image/semantic_segm/test_model.py b/tests/image/semantic_segm/test_model.py index 03069563a5..142e64c614 100644 --- a/tests/image/semantic_segm/test_model.py +++ b/tests/image/semantic_segm/test_model.py @@ -21,12 +21,13 @@ from flash import Trainer from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE, _TOPIC_SERVE_AVAILABLE +from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE, _TOPIC_IMAGE_AVAILABLE, _TOPIC_SERVE_AVAILABLE from flash.image import SemanticSegmentation from flash.image.segmentation.data import SemanticSegmentationData from tests.helpers.task_tester import TaskTester +@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP") class TestSemanticSegmentation(TaskTester): task = SemanticSegmentation task_args = (2,) @@ -59,13 +60,13 @@ def example_test_sample(self): return self.example_train_sample -@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP") def test_non_existent_backbone(): with pytest.raises(KeyError): SemanticSegmentation(2, "i am never going to implement this lol") -@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP") def test_freeze(): model = SemanticSegmentation(2) model.freeze() @@ -73,7 +74,7 @@ def test_freeze(): assert p.requires_grad is False -@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP") def test_unfreeze(): model = SemanticSegmentation(2) model.unfreeze() @@ -81,7 +82,7 @@ def test_unfreeze(): assert p.requires_grad is True -@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP") def test_predict_tensor(): img = torch.rand(1, 3, 64, 64) model = SemanticSegmentation(2, backbone="mobilenetv3_large_100") @@ -93,7 +94,7 @@ def test_predict_tensor(): assert len(out[0][0][0]) == 64 -@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP") def test_predict_numpy(): img = np.ones((1, 3, 64, 64)) model = SemanticSegmentation(2, backbone="mobilenetv3_large_100") @@ -105,7 +106,8 @@ def test_predict_numpy(): assert len(out[0][0][0]) == 64 -@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") +@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="some serving") @mock.patch("flash._IS_TESTING", True) def test_serve(): model = SemanticSegmentation(2) @@ -113,6 +115,6 @@ def test_serve(): model.serve() -@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP") def test_available_pretrained_weights(): assert SemanticSegmentation.available_pretrained_weights("resnet18") == ["imagenet", "ssl", "swsl"] diff --git a/tests/image/semantic_segm/test_output.py b/tests/image/semantic_segm/test_output.py index f431558e7b..df410ccdff 100644 --- a/tests/image/semantic_segm/test_output.py +++ b/tests/image/semantic_segm/test_output.py @@ -15,12 +15,19 @@ import torch from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TOPIC_IMAGE_AVAILABLE +from flash.core.utilities.imports import ( + _FIFTYONE_AVAILABLE, + _SEGMENTATION_MODELS_AVAILABLE, + _TOPIC_IMAGE_AVAILABLE, + _TOPIC_SERVE_AVAILABLE, +) from flash.image.segmentation.output import FiftyOneSegmentationLabelsOutput, SegmentationLabelsOutput +@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, "image libraries aren't installed.") +@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP") +@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="some serving") class TestSemanticSegmentationLabelsOutput: - @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, "image libraries aren't installed.") @staticmethod def test_smoke(): serial = SegmentationLabelsOutput() @@ -28,7 +35,6 @@ def test_smoke(): assert serial.labels_map is None assert serial.visualize is False - @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, "image libraries aren't installed.") @staticmethod def test_exception(): serial = SegmentationLabelsOutput() @@ -41,7 +47,6 @@ def test_exception(): sample = torch.zeros(2, 3) serial.transform(sample) - @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, "image libraries aren't installed.") @staticmethod def test_serialize(): serial = SegmentationLabelsOutput() @@ -54,7 +59,6 @@ def test_serialize(): assert torch.tensor(classes)[1, 2] == 1 assert torch.tensor(classes)[0, 1] == 3 - @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") @staticmethod def test_serialize_fiftyone(): From 869c61e6482197c68bcf41916ef983a53308ec03 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Sun, 7 May 2023 00:08:14 +0200 Subject: [PATCH 22/39] bump versions & cleaning (#1487) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/source/general/finetuning.rst | 21 +------- requirements.txt | 6 +-- requirements/datatype_pointcloud.txt | 7 ++- requirements/datatype_tabular.txt | 2 +- src/flash/core/classification.py | 9 +--- src/flash/core/model.py | 12 ++--- src/flash/core/trainer.py | 44 +--------------- src/flash/core/utilities/imports.py | 5 -- src/flash/image/classification/adapters.py | 18 ++----- .../classification/integrations/baal/loop.py | 50 +++++-------------- src/flash/image/segmentation/model.py | 5 +- src/flash/pointcloud/segmentation/model.py | 6 +-- src/flash/text/question_answering/model.py | 12 +---- src/flash/text/seq2seq/summarization/model.py | 11 +--- src/flash/text/seq2seq/translation/model.py | 6 +-- tests/core/utilities/test_lightning_cli.py | 14 ++---- .../test_training_strategies.py | 12 ++--- 17 files changed, 45 insertions(+), 195 deletions(-) diff --git a/docs/source/general/finetuning.rst b/docs/source/general/finetuning.rst index 0854740d77..97a4d4465c 100644 --- a/docs/source/general/finetuning.rst +++ b/docs/source/general/finetuning.rst @@ -62,7 +62,7 @@ Finetune strategies ) model = ImageClassifier(backbone="resnet18", num_classes=2) - trainer = flash.Trainer(max_epochs=1, checkpoint_callback=False) + trainer = flash.Trainer(max_epochs=1) Finetuning is very task specific. Each task encodes the best finetuning practices for that task. However, Flash gives you a few default strategies for finetuning. @@ -104,11 +104,6 @@ The freeze strategy keeps the backbone frozen throughout. trainer.finetune(model, datamodule, strategy="freeze") -.. testoutput:: strategies - :hide: - - ... - The pseudocode looks like: .. code-block:: python @@ -140,11 +135,6 @@ For example, to unfreeze after epoch 7: trainer.finetune(model, datamodule, strategy=("freeze_unfreeze", 7)) -.. testoutput:: strategies - :hide: - - ... - Under the hood, the pseudocode looks like: .. code-block:: python @@ -180,10 +170,6 @@ Here's an example where: trainer.finetune(model, datamodule, strategy=("unfreeze_milestones", ((3, 8), 2))) -.. testoutput:: strategies - :hide: - - ... Under the hood, the pseudocode looks like: @@ -238,11 +224,6 @@ For even more customization, create your own finetuning callback. Learn more abo # Pass the callback to trainer.finetune trainer.finetune(model, datamodule, strategy=FeatureExtractorFreezeUnfreeze(unfreeze_epoch=5)) -.. testoutput:: strategies - :hide: - - ... - Working with DeepSpeed ====================== diff --git a/requirements.txt b/requirements.txt index be60c76589..3c5740315d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,12 +4,12 @@ packaging <23.0 setuptools <=59.5.0 # Prevent install bug with tensorboard numpy <1.24 # strict - freeze for using np.long torch >=1.7.1 -torchmetrics >0.5.1, <0.11.0 # strict -pytorch-lightning >=1.3.6, <1.9.0 # strict +torchmetrics >=0.7.0, <0.11.0 # strict +pytorch-lightning >=1.6.0, <1.9.0 # strict pyDeprecate pandas >=1.1.0, <=1.5.2 jsonargparse[signatures] >=3.17.0, <=4.9.0 click >=7.1.2, <=8.1.3 protobuf <=3.20.1 fsspec[http] >=2022.5.0,<=2022.7.1 -lightning-utilities >=0.3.0 +lightning-utilities >=0.4.1 diff --git a/requirements/datatype_pointcloud.txt b/requirements/datatype_pointcloud.txt index 52666d7698..c3cc0d490f 100644 --- a/requirements/datatype_pointcloud.txt +++ b/requirements/datatype_pointcloud.txt @@ -1,7 +1,6 @@ # NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup -open3d ==0.13 -# Version mismatch: Open3D needs PyTorch version 1.7.*, but version 1.8.1+cpu is installed! -torch >=1.7.0, <1.8.0 -torchvision >=0.8.0, <0.9.0 +open3d >=0.17.0, <0.18.0 +# torch >=1.8.0, <1.9.0 +# torchvision >0.9.0, <0.10.0 tensorboard <=2.11.0 diff --git a/requirements/datatype_tabular.txt b/requirements/datatype_tabular.txt index 30ca9c29d5..8782e376f6 100644 --- a/requirements/datatype_tabular.txt +++ b/requirements/datatype_tabular.txt @@ -1,7 +1,7 @@ # NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup scikit-learn <=1.2.0 -pytorch-forecasting >=0.9.0, <=0.10.3 +pytorch-forecasting >=0.10.0, <=0.10.3 # pytorch-tabular >=1.0.2, <1.0.3 # pending requirements resolving https://github.com/manujosephv/pytorch_tabular/archive/refs/heads/main.zip torchmetrics >=0.10.0 diff --git a/src/flash/core/classification.py b/src/flash/core/classification.py index a236e1486d..1edab7ee76 100644 --- a/src/flash/core/classification.py +++ b/src/flash/core/classification.py @@ -17,14 +17,14 @@ import torch.nn.functional as F from pytorch_lightning.utilities import rank_zero_warn from torch import Tensor -from torchmetrics import Accuracy, Metric +from torchmetrics import Accuracy, F1Score, Metric from flash.core.adapter import AdapterTask from flash.core.data.io.input import DataKeys from flash.core.data.io.output import Output from flash.core.model import Task from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TM_GREATER_EQUAL_0_7_0, lazy_import, requires +from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires from flash.core.utilities.providers import _FIFTYONE if _FIFTYONE_AVAILABLE: @@ -36,11 +36,6 @@ Classification = None Classifications = None -if _TM_GREATER_EQUAL_0_7_0: - from torchmetrics import F1Score -else: - from torchmetrics import F1 as F1Score - CLASSIFICATION_OUTPUTS = FlashRegistry("outputs") diff --git a/src/flash/core/model.py b/src/flash/core/model.py index 1bc19203f1..5f9535c3e3 100644 --- a/src/flash/core/model.py +++ b/src/flash/core/model.py @@ -47,7 +47,7 @@ from flash.core.registry import FlashRegistry from flash.core.serve.composition import Composition from flash.core.utilities.apply_func import get_callable_dict -from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, _TOPIC_CORE_AVAILABLE, requires +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE, requires from flash.core.utilities.providers import _HUGGINGFACE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import ( @@ -390,10 +390,6 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: for name, metric in metrics.items(): if isinstance(metric, torchmetrics.metric.Metric): metric(y_hat, y) - # PL 1.4.0 -> 1.4.9 tries to deepcopy the metric. - # Sometimes _forward_cache is not a leaf, so we convert it to one. - if not metric._forward_cache.is_leaf and not _PL_GREATER_EQUAL_1_5_0: - metric._forward_cache = metric._forward_cache.clone().detach() logs[name] = metric # log the metric itself if it is of type Metric else: logs[name] = metric(y_hat, y) @@ -433,7 +429,7 @@ def forward(self, x: Any) -> Any: def training_step(self, batch: Any, batch_idx: int) -> Any: output = self.step(batch, batch_idx, self.train_metrics) - log_kwargs = {"batch_size": output.get(OutputKeys.BATCH_SIZE, None)} if _PL_GREATER_EQUAL_1_5_0 else {} + log_kwargs = {"batch_size": output.get(OutputKeys.BATCH_SIZE, None)} self.log_dict( {f"train_{k}": v for k, v in output[OutputKeys.LOGS].items()}, on_step=True, @@ -445,7 +441,7 @@ def training_step(self, batch: Any, batch_idx: int) -> Any: def validation_step(self, batch: Any, batch_idx: int) -> None: output = self.step(batch, batch_idx, self.val_metrics) - log_kwargs = {"batch_size": output.get(OutputKeys.BATCH_SIZE, None)} if _PL_GREATER_EQUAL_1_5_0 else {} + log_kwargs = {"batch_size": output.get(OutputKeys.BATCH_SIZE, None)} self.log_dict( {f"val_{k}": v for k, v in output[OutputKeys.LOGS].items()}, on_step=False, @@ -456,7 +452,7 @@ def validation_step(self, batch: Any, batch_idx: int) -> None: def test_step(self, batch: Any, batch_idx: int) -> None: output = self.step(batch, batch_idx, self.test_metrics) - log_kwargs = {"batch_size": output.get(OutputKeys.BATCH_SIZE, None)} if _PL_GREATER_EQUAL_1_5_0 else {} + log_kwargs = {"batch_size": output.get(OutputKeys.BATCH_SIZE, None)} self.log_dict( {f"test_{k}": v for k, v in output[OutputKeys.LOGS].items()}, on_step=False, diff --git a/src/flash/core/trainer.py b/src/flash/core/trainer.py index 6a9ab8b94a..fa40cb7f4f 100644 --- a/src/flash/core/trainer.py +++ b/src/flash/core/trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -import math import warnings from argparse import ArgumentParser, Namespace from functools import wraps @@ -23,7 +22,6 @@ from pytorch_lightning import Trainer as PlTrainer from pytorch_lightning.accelerators.tpu import TPUAccelerator from pytorch_lightning.callbacks import BaseFinetuning -from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.argparse import add_argparse_args, get_init_arguments_and_types, parse_env_variables from torch.utils.data import DataLoader @@ -33,7 +31,6 @@ from flash.core.data.io.transform_predictions import TransformPredictions from flash.core.model import Task from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_0, _PL_GREATER_EQUAL_1_5_0, _PL_GREATER_EQUAL_1_6_0 def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): @@ -260,43 +257,4 @@ def configure_optimizers(self): ) return [optimizer], [scheduler] """ - if _PL_GREATER_EQUAL_1_6_0: - return super().estimated_stepping_batches - # Copied from PL 1.6 - accumulation_scheduler = self.accumulation_scheduler - - if accumulation_scheduler.epochs != [0]: - raise ValueError( - "Estimated stepping batches cannot be computed with different" - " `accumulate_grad_batches` at different epochs." - ) - - # infinite training - if self.max_epochs == -1 and self.max_steps == -1: - return float("inf") - - if self.train_dataloader is None: - rank_zero_info("Loading `train_dataloader` to estimate number of stepping batches.") - if _PL_GREATER_EQUAL_1_5_0: - self.reset_train_dataloader() - else: - self.reset_train_dataloader(self.lightning_module) - - total_batches = self.num_training_batches - - # iterable dataset - if total_batches == float("inf"): - return self.max_steps - - if _PL_GREATER_EQUAL_1_4_0: - self.accumulate_grad_batches = accumulation_scheduler.get_accumulate_grad_batches(self.current_epoch) - else: - # Call the callback hook manually to guarantee that `self.accumulate_grad_batches` has been set - accumulation_scheduler.on_train_epoch_start(self, self.lightning_module) - effective_batch_size = self.accumulate_grad_batches - max_estimated_steps = math.ceil(total_batches / effective_batch_size) * max(self.max_epochs, 1) - - max_estimated_steps = ( - min(max_estimated_steps, self.max_steps) if self.max_steps not in [None, -1] else max_estimated_steps - ) - return max_estimated_steps + return super().estimated_stepping_batches diff --git a/src/flash/core/utilities/imports.py b/src/flash/core/utilities/imports.py index 24f787899f..7c3bb75ef8 100644 --- a/src/flash/core/utilities/imports.py +++ b/src/flash/core/utilities/imports.py @@ -83,14 +83,9 @@ class Image: if Version: _TORCHVISION_GREATER_EQUAL_0_9 = compare_version("torchvision", operator.ge, "0.9.0") - _PL_GREATER_EQUAL_1_4_3 = compare_version("pytorch_lightning", operator.ge, "1.4.3") - _PL_GREATER_EQUAL_1_4_0 = compare_version("pytorch_lightning", operator.ge, "1.4.0") - _PL_GREATER_EQUAL_1_5_0 = compare_version("pytorch_lightning", operator.ge, "1.5.0") - _PL_GREATER_EQUAL_1_6_0 = compare_version("pytorch_lightning", operator.ge, "1.6.0rc0") _PL_GREATER_EQUAL_1_8_0 = compare_version("pytorch_lightning", operator.ge, "1.8.0") _PANDAS_GREATER_EQUAL_1_3_0 = compare_version("pandas", operator.ge, "1.3.0") _ICEVISION_GREATER_EQUAL_0_11_0 = compare_version("icevision", operator.ge, "0.11.0") - _TM_GREATER_EQUAL_0_7_0 = compare_version("torchmetrics", operator.ge, "0.7.0") _TM_GREATER_EQUAL_0_10_0 = compare_version("torchmetrics", operator.ge, "0.10.0") _BAAL_GREATER_EQUAL_1_5_2 = compare_version("baal", operator.ge, "1.5.2") diff --git a/src/flash/image/classification/adapters.py b/src/flash/image/classification/adapters.py index cc1cd0316f..796008472c 100644 --- a/src/flash/image/classification/adapters.py +++ b/src/flash/image/classification/adapters.py @@ -21,6 +21,7 @@ import torch from lightning_utilities.core.rank_zero import WarningCache from pytorch_lightning import LightningModule +from pytorch_lightning.strategies import DataParallelStrategy, DDPSpawnStrategy, DDPStrategy from pytorch_lightning.trainer.states import TrainerFn from torch import Tensor, nn from torch.utils.data import DataLoader, IterableDataset, Sampler @@ -32,17 +33,12 @@ from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.compatibility import accelerator_connector -from flash.core.utilities.imports import _LEARN2LEARN_AVAILABLE, _PL_GREATER_EQUAL_1_6_0 +from flash.core.utilities.imports import _LEARN2LEARN_AVAILABLE from flash.core.utilities.providers import _LEARN2LEARN from flash.core.utilities.stability import beta from flash.core.utilities.url_error import catch_url_error from flash.image.classification.integrations.learn2learn import TaskDataParallel, TaskDistributedDataParallel -if _PL_GREATER_EQUAL_1_6_0: - from pytorch_lightning.strategies import DataParallelStrategy, DDPSpawnStrategy, DDPStrategy -else: - from pytorch_lightning.plugins import DataParallelPlugin, DDPPlugin, DDPSpawnPlugin - warning_cache = WarningCache() @@ -241,10 +237,7 @@ def _convert_dataset( task_collate=self._identity_task_collate_fn, ) - if _PL_GREATER_EQUAL_1_6_0: - is_ddp_or_ddp_spawn = isinstance(trainer.strategy, (DDPStrategy, DDPSpawnStrategy)) - else: - is_ddp_or_ddp_spawn = isinstance(trainer.training_type_plugin, (DDPPlugin, DDPSpawnPlugin)) + is_ddp_or_ddp_spawn = isinstance(trainer.strategy, (DDPStrategy, DDPSpawnStrategy)) if is_ddp_or_ddp_spawn: # when running in a distributed data parallel way, # we are actually sampling one task per device. @@ -260,10 +253,7 @@ def _convert_dataset( self.trainer.accumulated_grad_batches = self.meta_batch_size / trainer.world_size else: devices = 1 - if _PL_GREATER_EQUAL_1_6_0: - is_data_parallel = isinstance(trainer.strategy, DataParallelStrategy) - else: - is_data_parallel = isinstance(trainer.training_type_plugin, DataParallelPlugin) + is_data_parallel = isinstance(trainer.strategy, DataParallelStrategy) if is_data_parallel: # when using DP, we need to sample n tasks, so it can split across multiple devices. devices = accelerator_connector(trainer).devices diff --git a/src/flash/image/classification/integrations/baal/loop.py b/src/flash/image/classification/integrations/baal/loop.py index bb10a6c9cd..451d364309 100644 --- a/src/flash/image/classification/integrations/baal/loop.py +++ b/src/flash/image/classification/integrations/baal/loop.py @@ -15,43 +15,29 @@ from typing import Any, Dict, Optional from pytorch_lightning import LightningModule +from pytorch_lightning.loops import Loop +from pytorch_lightning.loops.fit_loop import FitLoop +from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource +from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus from pytorch_lightning.utilities.model_helpers import is_overridden from torch import Tensor import flash from flash.core.data.utils import _STAGES_PREFIX -from flash.core.utilities.imports import ( - _PL_GREATER_EQUAL_1_4_0, - _PL_GREATER_EQUAL_1_5_0, - _PL_GREATER_EQUAL_1_6_0, - requires, -) +from flash.core.utilities.imports import requires from flash.core.utilities.stability import beta from flash.core.utilities.stages import RunningStage from flash.image.classification.integrations.baal.data import ActiveLearningDataModule from flash.image.classification.integrations.baal.dropout import InferenceMCDropoutTask -if _PL_GREATER_EQUAL_1_4_0: - from pytorch_lightning.loops import Loop - from pytorch_lightning.loops.fit_loop import FitLoop - from pytorch_lightning.trainer.progress import Progress -else: - Loop = object - FitLoop = object - -if not _PL_GREATER_EQUAL_1_5_0: - from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader -else: - from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource - @beta("The BaaL integration is currently in Beta.") class ActiveLearningLoop(Loop): max_epochs: int inference_model: InferenceMCDropoutTask - @requires("baal", (_PL_GREATER_EQUAL_1_4_0, "pytorch-lightning>=1.4.0")) + @requires("baal") def __init__(self, label_epoch_frequency: int, inference_iteration: int = 2, should_reset_weights: bool = True): """The `ActiveLearning Loop` describes the following training procedure. This loop is connected with the `ActiveLearningTrainer` @@ -160,12 +146,7 @@ def __getattr__(self, key): return self.__dict__[key] def _connect(self, model: LightningModule): - if _PL_GREATER_EQUAL_1_6_0: - self.trainer.strategy.connect(model) - elif _PL_GREATER_EQUAL_1_5_0: - self.trainer.training_type_plugin.connect(model) - else: - self.trainer.accelerator.connect(model) + self.trainer.strategy.connect(model) def _reset_fitting(self): self.trainer.state.fn = TrainerFn.FITTING @@ -194,18 +175,11 @@ def _reset_dataloader_for_stage(self, running_state: RunningStage): ) if dataloader: - if _PL_GREATER_EQUAL_1_5_0: - setattr( - self.trainer._data_connector, - f"_{dataloader_name}_source", - _DataLoaderSource(self.trainer.datamodule, dataloader_name), - ) - else: - setattr( - self.trainer.lightning_module, - dataloader_name, - _PatchDataLoader(dataloader(), running_state), - ) + setattr( + self.trainer._data_connector, + f"_{dataloader_name}_source", + _DataLoaderSource(self.trainer.datamodule, dataloader_name), + ) setattr(self.trainer, dataloader_name, None) # TODO: Resolve this within PyTorch Lightning. try: diff --git a/src/flash/image/segmentation/model.py b/src/flash/image/segmentation/model.py index f07cb11054..121150e60c 100644 --- a/src/flash/image/segmentation/model.py +++ b/src/flash/image/segmentation/model.py @@ -24,7 +24,6 @@ from flash.core.registry import FlashRegistry from flash.core.serve import Composition from flash.core.utilities.imports import ( - _TM_GREATER_EQUAL_0_7_0, _TM_GREATER_EQUAL_0_10_0, _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_9, @@ -58,10 +57,8 @@ class InterpolationMode: if _TM_GREATER_EQUAL_0_10_0: from torchmetrics.classification import MulticlassJaccardIndex as JaccardIndex -elif _TM_GREATER_EQUAL_0_7_0: - from torchmetrics import JaccardIndex else: - from torchmetrics import IoU as JaccardIndex + from torchmetrics import JaccardIndex class SemanticSegmentationOutputTransform(OutputTransform): diff --git a/src/flash/pointcloud/segmentation/model.py b/src/flash/pointcloud/segmentation/model.py index 79363267a3..9d38c9044c 100644 --- a/src/flash/pointcloud/segmentation/model.py +++ b/src/flash/pointcloud/segmentation/model.py @@ -24,7 +24,7 @@ from flash.core.data.io.input_transform import InputTransform from flash.core.data.utilities.collate import wrap_collate from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _TM_GREATER_EQUAL_0_7_0, _TM_GREATER_EQUAL_0_10_0, _TOPIC_POINTCLOUD_AVAILABLE +from flash.core.utilities.imports import _TM_GREATER_EQUAL_0_10_0, _TOPIC_POINTCLOUD_AVAILABLE from flash.core.utilities.stability import beta from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.pointcloud.segmentation.backbones import POINTCLOUD_SEGMENTATION_BACKBONES @@ -35,10 +35,8 @@ if _TM_GREATER_EQUAL_0_10_0: from torchmetrics.classification import MulticlassJaccardIndex as JaccardIndex -elif _TM_GREATER_EQUAL_0_7_0: - from torchmetrics import JaccardIndex else: - from torchmetrics import IoU as JaccardIndex + from torchmetrics import JaccardIndex @beta("Point cloud segmentation is currently in Beta.") diff --git a/src/flash/text/question_answering/model.py b/src/flash/text/question_answering/model.py index 478c2778a9..e999825777 100644 --- a/src/flash/text/question_answering/model.py +++ b/src/flash/text/question_answering/model.py @@ -32,7 +32,7 @@ from flash.core.data.io.input import DataKeys from flash.core.model import Task from flash.core.registry import ExternalRegistry, FlashRegistry -from flash.core.utilities.imports import _TM_GREATER_EQUAL_0_7_0, _TOPIC_TEXT_AVAILABLE +from flash.core.utilities.imports import _TOPIC_TEXT_AVAILABLE from flash.core.utilities.providers import _HUGGINGFACE from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.text.ort_callback import ORTCallback @@ -144,15 +144,7 @@ def __init__( self.null_score_diff_threshold = null_score_diff_threshold self._initialize_model_specific_parameters() - if _TM_GREATER_EQUAL_0_7_0: - self.rouge = ROUGEScore( - use_stemmer=use_stemmer, - ) - else: - self.rouge = ROUGEScore( - True, - use_stemmer=use_stemmer, - ) + self.rouge = ROUGEScore(use_stemmer=use_stemmer) def _generate_answers(self, pred_start_logits, pred_end_logits, examples): all_predictions = collections.OrderedDict() diff --git a/src/flash/text/seq2seq/summarization/model.py b/src/flash/text/seq2seq/summarization/model.py index 5c2fab7947..926a9823b6 100644 --- a/src/flash/text/seq2seq/summarization/model.py +++ b/src/flash/text/seq2seq/summarization/model.py @@ -16,7 +16,6 @@ from torch import Tensor from torchmetrics.text.rouge import ROUGEScore -from flash.core.utilities.imports import _TM_GREATER_EQUAL_0_7_0 from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.text.seq2seq.core.model import Seq2SeqTask @@ -76,15 +75,7 @@ def __init__( num_beams=num_beams, enable_ort=enable_ort, ) - if _TM_GREATER_EQUAL_0_7_0: - self.rouge = ROUGEScore( - use_stemmer=use_stemmer, - ) - else: - self.rouge = ROUGEScore( - True, - use_stemmer=use_stemmer, - ) + self.rouge = ROUGEScore(use_stemmer=use_stemmer) @property def task(self) -> str: diff --git a/src/flash/text/seq2seq/translation/model.py b/src/flash/text/seq2seq/translation/model.py index d6365d3864..2be597e052 100644 --- a/src/flash/text/seq2seq/translation/model.py +++ b/src/flash/text/seq2seq/translation/model.py @@ -15,7 +15,6 @@ from torchmetrics import BLEUScore -from flash.core.utilities.imports import _TM_GREATER_EQUAL_0_7_0 from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.text.seq2seq.core.model import Seq2SeqTask @@ -94,8 +93,5 @@ def compute_metrics(self, generated_tokens, batch, prefix): translate_corpus = self.decode(generated_tokens) translate_corpus = [line for line in translate_corpus] - if _TM_GREATER_EQUAL_0_7_0: - result = self.bleu(translate_corpus, reference_corpus) - else: - result = self.bleu(reference_corpus, translate_corpus) + result = self.bleu(translate_corpus, reference_corpus) self.log(f"{prefix}_bleu_score", result, on_step=False, on_epoch=True, prog_bar=True) diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py index f4a4c2493a..f2d4dda210 100644 --- a/tests/core/utilities/test_lightning_cli.py +++ b/tests/core/utilities/test_lightning_cli.py @@ -21,12 +21,7 @@ from torch import nn from flash.core.utilities.compatibility import accelerator_connector -from flash.core.utilities.imports import ( - _PL_GREATER_EQUAL_1_4_0, - _PL_GREATER_EQUAL_1_6_0, - _TOPIC_CORE_AVAILABLE, - _TORCHVISION_AVAILABLE, -) +from flash.core.utilities.imports import _TOPIC_CORE_AVAILABLE, _TORCHVISION_AVAILABLE from flash.core.utilities.lightning_cli import ( LightningArgumentParser, LightningCLI, @@ -163,7 +158,7 @@ def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): ({}, {}), (dict(logger=False), {}), (dict(logger=False), dict(logger=True)), - (dict(logger=False), dict(checkpoint_callback=True)), + (dict(logger=False), dict(enable_checkpointing=True)), ], ) def test_init_from_argparse_args(cli_args, extra_args): @@ -278,7 +273,7 @@ def add_arguments_to_parser(self, parser): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") -@pytest.mark.skipif(_PL_GREATER_EQUAL_1_6_0, reason="Bugs in PL >= 1.6.0") +@pytest.mark.xfail(reason="Bugs in PL >= 1.6.0") def test_lightning_cli_args_cluster_environments(tmpdir): plugins = [dict(class_path="pytorch_lightning.plugins.environments.SLURMEnvironment")] @@ -588,8 +583,7 @@ def on_exception(self, execption): dict(accelerator="cpu", strategy="ddp", plugins="ddp_find_unused_parameters_false"), ), ) -@pytest.mark.skipif(not _PL_GREATER_EQUAL_1_4_0, reason="Bugs in PL < 1.4.0") -@pytest.mark.skipif(_PL_GREATER_EQUAL_1_6_0, reason="Bugs in PL >= 1.6.0") +@pytest.mark.xfail(reason="Bugs in PL >= 1.6.0") def test_cli_ddp_spawn_save_config_callback(tmpdir, logger, trainer_kwargs): with mock.patch("sys.argv", ["any.py"]), pytest.raises(CustomException): LightningCLI( diff --git a/tests/image/classification/test_training_strategies.py b/tests/image/classification/test_training_strategies.py index 99231a016d..cd66fb2673 100644 --- a/tests/image/classification/test_training_strategies.py +++ b/tests/image/classification/test_training_strategies.py @@ -20,7 +20,7 @@ from flash import Trainer from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _LEARN2LEARN_AVAILABLE, _PL_GREATER_EQUAL_1_6_0, _TOPIC_IMAGE_AVAILABLE +from flash.core.utilities.imports import _LEARN2LEARN_AVAILABLE, _TOPIC_IMAGE_AVAILABLE from flash.image import ImageClassificationData, ImageClassifier from flash.image.classification.adapters import TRAINING_STRATEGIES from tests.image.classification.test_data import _rand_image @@ -87,10 +87,7 @@ def _test_learn2learning_training_strategies(gpus, training_strategy, tmpdir, ac training_strategy_kwargs={"ways": dm.num_classes, "shots": 4, "meta_batch_size": 4}, ) - if _PL_GREATER_EQUAL_1_6_0: - trainer = Trainer(fast_dev_run=2, gpus=gpus, strategy=strategy) - else: - trainer = Trainer(fast_dev_run=2, gpus=gpus, accelerator=accelerator) + trainer = Trainer(fast_dev_run=2, gpus=gpus, strategy=strategy) trainer.fit(model, datamodule=dm) @@ -115,7 +112,4 @@ def test_wrongly_specified_training_strategies(): @pytest.mark.skipif(not os.getenv("FLASH_RUNNING_SPECIAL_TESTS", "0") == "1", reason="Should run with special test") @pytest.mark.skipif(not _LEARN2LEARN_AVAILABLE, reason="image and learn2learn libraries aren't installed.") def test_learn2learn_training_strategies_ddp(tmpdir): - if _PL_GREATER_EQUAL_1_6_0: - _test_learn2learning_training_strategies(2, "prototypicalnetworks", tmpdir, strategy="ddp") - else: - _test_learn2learning_training_strategies(2, "prototypicalnetworks", tmpdir, accelerator="ddp") + _test_learn2learning_training_strategies(2, "prototypicalnetworks", tmpdir, strategy="ddp") From 5e4b1449dcb84fa21b4964a830364bff4fc29b51 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 10 May 2023 15:10:24 +0200 Subject: [PATCH 23/39] pkg req. & resolve GPU testing (#1547) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .azure/gpu-special-tests.yml | 13 +- .azure/template-examples.yml | 48 +++---- requirements/datatype_tabular.txt | 2 +- requirements/devel.txt | 13 -- setup.py | 7 +- src/flash/core/classification.py | 6 +- .../integrations/pytorch_tabular/backbones.py | 12 +- tests/core/utilities/test_lightning_cli.py | 123 +++++++++--------- tests/special_tests.sh | 2 +- 9 files changed, 111 insertions(+), 115 deletions(-) delete mode 100644 requirements/devel.txt diff --git a/.azure/gpu-special-tests.yml b/.azure/gpu-special-tests.yml index 5b3060ac6a..f7c26da082 100644 --- a/.azure/gpu-special-tests.yml +++ b/.azure/gpu-special-tests.yml @@ -19,18 +19,15 @@ jobs: timeoutInMinutes: "45" # how much time to give 'run always even if cancelled tasks' before stopping them cancelTimeoutInMinutes: "2" - pool: "lit-rtx-3090" variables: DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' ) - container: - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.12-cuda11.6.1" + # image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.11" + image: "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime" options: "--ipc=host --gpus=all" - workspace: clean: all - steps: - bash: echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)" @@ -52,8 +49,7 @@ jobs: displayName: 'Sanity check' - bash: | - # python -m pip install "pip==20.1" - pip install '.[image,test]' -r requirements/testing_image.txt + pip install '.[image,test]' -r requirements/testing_image.txt -U pip list env: FREEZE_REQUIREMENTS: 1 @@ -66,7 +62,6 @@ jobs: - bash: | python -m coverage report python -m coverage xml - python -m coverage html - python -m codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure + # python -m codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure ls -l displayName: 'Statistics' diff --git a/.azure/template-examples.yml b/.azure/template-examples.yml index ab9445c170..ce88118ba6 100644 --- a/.azure/template-examples.yml +++ b/.azure/template-examples.yml @@ -1,7 +1,7 @@ jobs: - - ${{ each dom in parameters.domains }}: + - ${{ each topic in parameters.domains }}: - job: - displayName: "domain ${{dom}} with 2 GPU" + displayName: "domain ${{topic}} with 2 GPU" # how long to run the job before automatically cancelling timeoutInMinutes: "45" # how much time to give 'run always even if cancelled tasks' before stopping them @@ -14,16 +14,23 @@ jobs: # this need to have installed docker in the base image... container: # base ML image: mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04 - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.11" - # image: "pytorch/pytorch:1.8.1-cuda11.0-cudnn8-runtime" - options: "-it --rm --gpus=all --shm-size=16g" + # image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.11" + image: "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime" + options: "-it --rm --gpus=all --shm-size=16g -v /usr/bin/docker:/tmp/docker:ro" workspace: clean: all steps: - - bash: echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)" - displayName: 'set visible devices' + - bash: | + echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)" + echo "##vso[task.setvariable variable=CONTAINER_ID]$(head -1 /proc/self/cgroup|cut -d/ -f3)" + displayName: 'Set environment variables' + + - script: | + /tmp/docker exec -t -u 0 $CONTAINER_ID \ + sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get -o Dpkg::Options::="--force-confold" -y install sudo" + displayName: 'Install Sudo in container (thanks Microsoft!)' - bash: | echo $CUDA_VISIBLE_DEVICES @@ -35,38 +42,31 @@ jobs: df -kh /dev/shm displayName: 'Image info & NVIDIA' - - bash: | + - script: | python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu >= 2, f'GPU: {mgpu}'" displayName: 'Sanity check' - - bash: | - pip install '.[${{dom}},test]' -r requirements/testing_${{dom}}.txt - pip list + - script: | + sudo apt-get install -y build-essential gcc cmake software-properties-common + python -m pip install "pip==22.2.1" + pip --version + pip install '.[${{topic}},test]' -r "requirements/testing_${{topic}}.txt" -U --prefer-binary env: FREEZE_REQUIREMENTS: 1 displayName: 'Install dependencies' - - bash: | - python -m coverage run --source flash -m pytest --durations=30 - env: - FLASH_TEST_TOPIC: ${{ dom }} + - script: | + pip list + python -m coverage run --source flash -m pytest tests/examples -vV --durations=30 displayName: 'Testing' - bash: | python -m coverage report python -m coverage xml - python -m coverage html - python -m codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure + # python -m codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure ls -l displayName: 'Statistics' - - task: PublishTestResults@2 - displayName: 'Publish test results' - inputs: - testResultsFiles: '$(Build.StagingDirectory)/test-results.xml' - testRunTitle: '$(Agent.OS) - $(Build.DefinitionName) - Python $(python.version)' - condition: succeededOrFailed() - - task: PublishCodeCoverageResults@1 displayName: 'Publish coverage report' inputs: diff --git a/requirements/datatype_tabular.txt b/requirements/datatype_tabular.txt index 8782e376f6..fc9eeb7775 100644 --- a/requirements/datatype_tabular.txt +++ b/requirements/datatype_tabular.txt @@ -3,6 +3,6 @@ scikit-learn <=1.2.0 pytorch-forecasting >=0.10.0, <=0.10.3 # pytorch-tabular >=1.0.2, <1.0.3 # pending requirements resolving -https://github.com/manujosephv/pytorch_tabular/archive/refs/heads/main.zip +pytorch-tabular @ https://github.com/manujosephv/pytorch_tabular/archive/refs/heads/main.zip torchmetrics >=0.10.0 omegaconf <=2.1.1, <=2.1.1 diff --git a/requirements/devel.txt b/requirements/devel.txt deleted file mode 100644 index 87b5d71632..0000000000 --- a/requirements/devel.txt +++ /dev/null @@ -1,13 +0,0 @@ --r ../requirements.txt - --r ./test.txt - --r ./docs.txt - --r ./datatype_image.txt - --r ./datatype_tabular.txt - --r ./datatype_text.txt - --r ./datatype_video.txt diff --git a/setup.py b/setup.py index 87dc5d3727..88267df0c9 100644 --- a/setup.py +++ b/setup.py @@ -77,7 +77,7 @@ def _augment_requirement(ln: str, comment_char: str = "#", unfreeze: bool = True is_strict = False req = ln.strip() # skip directly installed dependencies - if not req or any(c in req for c in ["http:", "https:", "@"]): + if not req or (unfreeze and any(c in req for c in ["http:", "https:", "@"])): return "" # remove version restrictions unless they are strict @@ -102,8 +102,9 @@ def _load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: boo lines = [ln.strip() for ln in file.readlines()] reqs = [_augment_requirement(ln, unfreeze=unfreeze) for ln in lines] reqs = [str(req) for req in reqs if req and not req.startswith("-r")] - # filter empty lines and containing @ which means redirect to some git/http - reqs = [req for req in reqs if not any(c in req for c in ["@", "http://", "https://"])] + if unfreeze: + # filter empty lines and containing @ which means redirect to some git/http + reqs = [req for req in reqs if not any(c in req for c in ["@", "http://", "https://"])] return reqs diff --git a/src/flash/core/classification.py b/src/flash/core/classification.py index 1edab7ee76..25973d79f1 100644 --- a/src/flash/core/classification.py +++ b/src/flash/core/classification.py @@ -59,7 +59,11 @@ def _build( self.labels = labels if metrics is None: - metrics = F1Score(num_classes) if (multi_label and num_classes) else Accuracy() + metrics = ( + F1Score(num_classes=num_classes, task="multilabel", top_k=1) + if (multi_label and num_classes) + else Accuracy() + ) if loss_fn is None: loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy diff --git a/src/flash/core/integrations/pytorch_tabular/backbones.py b/src/flash/core/integrations/pytorch_tabular/backbones.py index b3118fd9b8..6f2ca58d24 100644 --- a/src/flash/core/integrations/pytorch_tabular/backbones.py +++ b/src/flash/core/integrations/pytorch_tabular/backbones.py @@ -23,7 +23,7 @@ from flash.core.utilities.providers import _PYTORCH_TABULAR if _PYTORCHTABULAR_AVAILABLE: - import pytorch_tabular.models as models + import pytorch_tabular from omegaconf import DictConfig, OmegaConf from pytorch_tabular.config import ModelConfig from pytorch_tabular.models import ( @@ -46,7 +46,7 @@ def _read_parse_config(config, cls): if os.path.exists(config): _config = OmegaConf.load(config) if cls == ModelConfig: - cls = getattr(getattr(models, _config._module_src), _config._config_name) + cls = getattr(getattr(pytorch_tabular, _config._module_src), _config._config_name) config = cls( **{ k: v @@ -69,12 +69,16 @@ def load_pytorch_tabular( ): model_config = model_config_class(task=task_type, embedding_dims=parameters["embedding_dims"], **model_kwargs) model_config = _read_parse_config(model_config, ModelConfig) - model_callable = getattr(getattr(models, model_config._module_src), model_config._model_name) + model_callable = pytorch_tabular + for attr in model_config._module_src.split(".") + [model_config._model_name]: + model_callable = getattr(model_callable, attr) config = OmegaConf.merge( OmegaConf.create(parameters), OmegaConf.to_container(model_config), ) - model = model_callable(config=config, custom_loss=loss_fn, custom_metrics=metrics) + model = model_callable( + config=config, custom_loss=loss_fn, custom_metrics=metrics, inferred_config=DictConfig(config) + ) return model for model_config_class, name in zip( diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py index f2d4dda210..e40d3e946f 100644 --- a/tests/core/utilities/test_lightning_cli.py +++ b/tests/core/utilities/test_lightning_cli.py @@ -52,7 +52,7 @@ def test_default_args(mock_argparse, tmpdir): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") -@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--weights_save_path=./"], []]) +@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--default_root_dir=./"], []]) def test_add_argparse_args_redefined(cli_args): """Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness.""" parser = LightningArgumentParser(add_help=False, parse_as_dict=False) @@ -87,7 +87,6 @@ def test_add_argparse_args_redefined(cli_args): ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", dict(auto_lr_find=True, auto_scale_batch_size=False)), ("--limit_train_batches=100", dict(limit_train_batches=100)), ("--limit_train_batches 0.8", dict(limit_train_batches=0.8)), - ("--weights_summary=null", dict(weights_summary=None)), ], ) def test_parse_args_parsing(cli_args, expected): @@ -134,9 +133,9 @@ def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): ("--gpus 0,1", [0, 1]), ], ) -def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): +def test_parse_args_parsing_gpus(mocker, cli_args, expected_gpu): """Test parsing of gpus and instantiation of Trainer.""" - monkeypatch.setattr("torch.cuda.device_count", lambda: 2) + mocker.patch("lightning_lite.utilities.device_parser._get_all_available_gpus", return_value=[0, 1]) cli_args = cli_args.split(" ") if cli_args else [] parser = LightningArgumentParser(add_help=False, parse_as_dict=False) parser.add_lightning_class_args(Trainer, None) @@ -144,7 +143,7 @@ def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu): args = parser.parse_args() trainer = Trainer.from_argparse_args(args) - assert trainer.data_parallel_device_ids == expected_gpu + assert trainer.device_ids == expected_gpu @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @@ -225,6 +224,19 @@ def on_train_start(callback, trainer, _): assert hasattr(cli.trainer, "ran_asserts") and cli.trainer.ran_asserts +class TestModelCallbacks(BoringModel): + def on_fit_start(self): + callback = [c for c in self.trainer.callbacks if isinstance(c, LearningRateMonitor)] + assert len(callback) == 1 + assert callback[0].logging_interval == "epoch" + assert callback[0].log_momentum is True + + callback = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] + assert len(callback) == 1 + assert callback[0].monitor == "NAME" + self.trainer.ran_asserts = True + + @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_args_callbacks(tmpdir): callbacks = [ @@ -235,19 +247,8 @@ def test_lightning_cli_args_callbacks(tmpdir): dict(class_path="pytorch_lightning.callbacks.ModelCheckpoint", init_args=dict(monitor="NAME")), ] - class TestModel(BoringModel): - def on_fit_start(self): - callback = [c for c in self.trainer.callbacks if isinstance(c, LearningRateMonitor)] - assert len(callback) == 1 - assert callback[0].logging_interval == "epoch" - assert callback[0].log_momentum is True - callback = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] - assert len(callback) == 1 - assert callback[0].monitor == "NAME" - self.trainer.ran_asserts = True - with mock.patch("sys.argv", ["any.py", f"--trainer.callbacks={json.dumps(callbacks)}"]): - cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) + cli = LightningCLI(TestModelCallbacks, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) assert cli.trainer.ran_asserts @@ -272,19 +273,20 @@ def add_arguments_to_parser(self, parser): assert callback[0].logging_interval == "epoch" +class TestModelClusterEnv(BoringModel): + def on_fit_start(self): + # Ensure SLURMEnvironment is set, instead of default LightningEnvironment + assert isinstance(accelerator_connector(self.trainer)._cluster_environment, SLURMEnvironment) + self.trainer.ran_asserts = True + + @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.xfail(reason="Bugs in PL >= 1.6.0") def test_lightning_cli_args_cluster_environments(tmpdir): plugins = [dict(class_path="pytorch_lightning.plugins.environments.SLURMEnvironment")] - class TestModel(BoringModel): - def on_fit_start(self): - # Ensure SLURMEnvironment is set, instead of default LightningEnvironment - assert isinstance(accelerator_connector(self.trainer)._cluster_environment, SLURMEnvironment) - self.trainer.ran_asserts = True - with mock.patch("sys.argv", ["any.py", f"--trainer.plugins={json.dumps(plugins)}"]): - cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) + cli = LightningCLI(TestModelClusterEnv, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) assert cli.trainer.ran_asserts @@ -421,19 +423,20 @@ def test_lightning_cli_print_config(): assert outval["data"]["class_path"] == "tests.helpers.boring_model.BoringDataModule" +class MainModule(BoringModel): + def __init__( + self, + submodule1: LightningModule, + submodule2: LightningModule, + main_param: int = 1, + ): + super().__init__() + self.submodule1 = submodule1 + self.submodule2 = submodule2 + + @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_submodules(tmpdir): - class MainModule(BoringModel): - def __init__( - self, - submodule1: LightningModule, - submodule2: LightningModule, - main_param: int = 1, - ): - super().__init__() - self.submodule1 = submodule1 - self.submodule2 = submodule2 - config = """model: main_param: 2 submodule1: @@ -459,19 +462,20 @@ def __init__( assert isinstance(cli.model.submodule2, BoringModel) +class TestModuleTorch(BoringModel): + def __init__( + self, + activation: nn.Module = None, + transform: Optional[List[nn.Module]] = None, + ): + super().__init__() + self.activation = activation + self.transform = transform + + @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.skipif(torchvision_version < version.parse("0.8.0"), reason="torchvision>=0.8.0 is required") def test_lightning_cli_torch_modules(tmpdir): - class TestModule(BoringModel): - def __init__( - self, - activation: nn.Module = None, - transform: Optional[List[nn.Module]] = None, - ): - super().__init__() - self.activation = activation - self.transform = transform - config = """model: activation: class_path: torch.nn.LeakyReLU @@ -496,7 +500,7 @@ def __init__( ] with mock.patch("sys.argv", ["any.py"] + cli_args): - cli = LightningCLI(TestModule) + cli = LightningCLI(TestModuleTorch) assert isinstance(cli.model.activation, torch.nn.LeakyReLU) assert cli.model.activation.negative_slope == 0.2 @@ -698,6 +702,19 @@ def add_arguments_to_parser(self, parser): assert cli.trainer.lr_schedulers[0]["scheduler"].step_size == 50 +class TestModelOptLR(BoringModel): + def __init__( + self, + optim1: dict, + optim2: dict, + scheduler: dict, + ): + super().__init__() + self.optim1 = instantiate_class(self.parameters(), optim1) + self.optim2 = instantiate_class(self.parameters(), optim2) + self.scheduler = instantiate_class(self.optim1, scheduler) + + @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir): class MyLightningCLI(LightningCLI): @@ -706,18 +723,6 @@ def add_arguments_to_parser(self, parser): parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2") parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR, link_to="model.scheduler") - class TestModel(BoringModel): - def __init__( - self, - optim1: dict, - optim2: dict, - scheduler: dict, - ): - super().__init__() - self.optim1 = instantiate_class(self.parameters(), optim1) - self.optim2 = instantiate_class(self.parameters(), optim2) - self.scheduler = instantiate_class(self.optim1, scheduler) - cli_args = [ f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", @@ -727,7 +732,7 @@ def __init__( ] with mock.patch("sys.argv", ["any.py"] + cli_args): - cli = MyLightningCLI(TestModel) + cli = MyLightningCLI(TestModelOptLR) assert isinstance(cli.model.optim1, torch.optim.Adam) assert isinstance(cli.model.optim2, torch.optim.SGD) diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 99cac8929a..96a060182c 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -31,7 +31,7 @@ linenos=$(echo "$grep_output" | cut -f2 -d:) linenos_arr=($linenos) # tests to skip - space separated -blocklist='test_pytorch_profiler_nested_emit_nvtx' +blocklist='' report='' for i in "${!files_arr[@]}"; do From d93102ac5f107732000e8a1430c8358be7397dbb Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 10 May 2023 22:07:58 +0200 Subject: [PATCH 24/39] drop unused TPU tests --- dockers/tpu-tests/Dockerfile | 33 ---------- dockers/tpu-tests/docker-entrypoint.sh | 8 --- dockers/tpu-tests/tpu_test_cases.jsonnet | 49 --------------- tests/tpu/__init__.py | 0 tests/tpu/test_sample_tpu.py | 18 ------ tests/tpu/test_tpu_multi_core.py | 55 ----------------- tests/tpu/test_tpu_single_core.py | 64 -------------------- tests/tpu_tests.sh | 77 ------------------------ 8 files changed, 304 deletions(-) delete mode 100644 dockers/tpu-tests/Dockerfile delete mode 100644 dockers/tpu-tests/docker-entrypoint.sh delete mode 100644 dockers/tpu-tests/tpu_test_cases.jsonnet delete mode 100644 tests/tpu/__init__.py delete mode 100644 tests/tpu/test_sample_tpu.py delete mode 100644 tests/tpu/test_tpu_multi_core.py delete mode 100644 tests/tpu/test_tpu_single_core.py delete mode 100755 tests/tpu_tests.sh diff --git a/dockers/tpu-tests/Dockerfile b/dockers/tpu-tests/Dockerfile deleted file mode 100644 index 9a1d5ac59f..0000000000 --- a/dockers/tpu-tests/Dockerfile +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -ARG PYTHON_VERSION=3.9 -ARG PYTORCH_VERSION=1.9 - -FROM pytorchlightning/pytorch_lightning:base-xla-py${PYTHON_VERSION}-torch${PYTORCH_VERSION} - -LABEL maintainer="Lightning-AI " - -COPY ./ ./lightning-flash/ - -RUN \ - pip install -q fire && \ - # drop unnecessary packages - pip install -r lightning-flash/requirements.txt --no-cache-dir - -COPY ./dockers/tpu-tests/docker-entrypoint.sh /usr/local/bin/ -RUN chmod +x /usr/local/bin/docker-entrypoint.sh - -ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"] -CMD ["bash"] diff --git a/dockers/tpu-tests/docker-entrypoint.sh b/dockers/tpu-tests/docker-entrypoint.sh deleted file mode 100644 index 57abc703c8..0000000000 --- a/dockers/tpu-tests/docker-entrypoint.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -# source ~/.bashrc -echo "running docker-entrypoint.sh" -# conda activate container -echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS -echo "printed TPU info" -export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}" -exec "$@" diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet deleted file mode 100644 index 5561424c76..0000000000 --- a/dockers/tpu-tests/tpu_test_cases.jsonnet +++ /dev/null @@ -1,49 +0,0 @@ -local base = import 'templates/base.libsonnet'; -local tpus = import 'templates/tpus.libsonnet'; -local utils = import "templates/utils.libsonnet"; - -local tputests = base.BaseTest { - frameworkPrefix: 'pl', - modelName: 'tpu-tests', - mode: 'postsubmit', - configMaps: [], - - timeout: 6000, # 100 minutes, in seconds. - - image: 'pytorchlightning/pytorch_lightning', - imageTag: 'base-xla-py{PYTHON_VERSION}-torch{PYTORCH_VERSION}', - - tpuSettings+: { - softwareVersion: 'pytorch-{PYTORCH_VERSION}', - }, - accelerator: tpus.v3_8, - - command: utils.scriptCommand( - ||| - source ~/.bashrc - conda activate lightning - mkdir -p /home/runner/work/lightning-flash && cd /home/runner/work/lightning-flash - git clone https://github.com/Lightning-AI/lightning-flash.git - cd lightning-flash - echo $PWD - git ls-remote --refs origin - git fetch origin "refs/pull/{PR_NUMBER}/head:pr/{PR_NUMBER}" && git checkout "pr/{PR_NUMBER}" - git checkout {SHA} - export FREEZE_REQUIREMENTS=1 - pip install -e .[test] - echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS - export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}" - cd tests - coverage run --source=lightning_flash -m pytest -vv --durations=0 ./ - echo "\n||| Running TPU Tests |||\n" - bash tpu_tests.sh - test_exit_code=$? - echo "\n||| END PYTEST LOGS |||\n" - coverage xml - cat coverage.xml | tr -d '\t' - test $test_exit_code -eq 0 - ||| - ), -}; - -tputests.oneshotJob diff --git a/tests/tpu/__init__.py b/tests/tpu/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/tpu/test_sample_tpu.py b/tests/tpu/test_sample_tpu.py deleted file mode 100644 index 78377f4a49..0000000000 --- a/tests/tpu/test_sample_tpu.py +++ /dev/null @@ -1,18 +0,0 @@ -import os - -import pytest -from pytorch_lightning.accelerators.tpu import TPUAccelerator - -from flash import Trainer - - -@pytest.mark.skipif(not os.getenv("FLASH_RUN_TPU_TESTS", "0") == "1", reason="Should run with tpu test") -def test_tpu_trainer_single(): - trainer = Trainer(accelerator="tpu", devices=1) - assert isinstance(trainer.accelerator, TPUAccelerator), "Expected device to be TPU" - - -@pytest.mark.skipif(not os.getenv("FLASH_RUN_TPU_TESTS", "0") == "1", reason="Should run with tpu test") -def test_tpu_trainer_multi_core(): - trainer = Trainer(accelerator="tpu", devices=8) - assert isinstance(trainer.accelerator, TPUAccelerator), "Expected device to be TPU" diff --git a/tests/tpu/test_tpu_multi_core.py b/tests/tpu/test_tpu_multi_core.py deleted file mode 100644 index c93328c596..0000000000 --- a/tests/tpu/test_tpu_multi_core.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os - -import pytest -import torch.nn.functional as F -from pytorch_lightning.accelerators.tpu import TPUAccelerator -from torch.utils.data import DataLoader - -import flash -from tests.core.test_finetuning import DummyDataset, TestTaskWithFinetuning -from tests.tpu.test_tpu_single_core import _assert_state_finished - -# Current state of TPU with Flash (as of v0.8 release) -# Multi Core: -# TPU Training, Validation are supported, but prediction is not. - - -@pytest.mark.skipif(not os.getenv("FLASH_RUN_TPU_TESTS", "0") == "1", reason="Should run with TPU test") -def test_tpu_finetuning(): - task = TestTaskWithFinetuning(loss_fn=F.nll_loss) - - trainer = flash.Trainer(max_epochs=1, devices=8, accelerator="tpu") - assert isinstance(trainer.accelerator, TPUAccelerator) - - dataloader = DataLoader(DummyDataset()) - trainer.finetune(model=task, train_dataloader=dataloader) - _assert_state_finished(trainer, "fit") - - -@pytest.mark.skipif(not os.getenv("FLASH_RUN_TPU_TESTS", "0") == "1", reason="Should run with TPU test") -def test_tpu_prediction(): - task = TestTaskWithFinetuning(loss_fn=F.nll_loss) - dataloader = DataLoader(DummyDataset()) - - trainer = flash.Trainer(max_epochs=1, devices=8, accelerator="tpu") - assert isinstance(trainer.accelerator, TPUAccelerator) - - trainer.fit(model=task, train_dataloader=dataloader, val_dataloaders=dataloader) - _assert_state_finished(trainer, "fit") - - with pytest.raises(NotImplementedError, match="not supported"): - trainer.predict(model=task, dataloaders=dataloader) - return diff --git a/tests/tpu/test_tpu_single_core.py b/tests/tpu/test_tpu_single_core.py deleted file mode 100644 index 1cafe10398..0000000000 --- a/tests/tpu/test_tpu_single_core.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os - -import pytest -import torch.nn.functional as F -from pytorch_lightning.accelerators.tpu import TPUAccelerator -from torch.utils.data import DataLoader - -import flash -from tests.core.test_finetuning import DummyDataset, TestTaskWithFinetuning -from tests.helpers.boring_model import BoringDataModule, BoringModel - -# Current state of TPU with Flash (as of v0.8 release) -# Single Core: -# TPU Training, Validation, and Prediction are supported. - - -# Helper function -def _assert_state_finished(trainer, fn_name): - assert trainer.state.finished and trainer.state.fn == fn_name - - -@pytest.mark.skipif(not os.getenv("FLASH_RUN_TPU_TESTS", "0") == "1", reason="Should run with TPU test") -def test_tpu_finetuning(): - task = TestTaskWithFinetuning(loss_fn=F.nll_loss) - - trainer = flash.Trainer(max_epochs=1, devices=1, accelerator="tpu") - assert isinstance(trainer.accelerator, TPUAccelerator) - - dataloader = DataLoader(DummyDataset()) - trainer.finetune(model=task, train_dataloader=dataloader) - _assert_state_finished(trainer, "fit") - - -@pytest.mark.skipif(not os.getenv("FLASH_RUN_TPU_TESTS", "0") == "1", reason="Should run with TPU test") -def test_tpu_prediction(): - boring_model = BoringModel() - boring_dm = BoringDataModule() - - trainer = flash.Trainer(fast_dev_run=True, devices=1, accelerator="tpu") - assert isinstance(trainer.accelerator, TPUAccelerator) - - trainer.fit(model=boring_model, datamodule=boring_dm) - _assert_state_finished(trainer, "fit") - trainer.validate(model=boring_model, datamodule=boring_dm) - _assert_state_finished(trainer, "validate") - trainer.test(model=boring_model, datamodule=boring_dm) - _assert_state_finished(trainer, "test") - - predictions = trainer.predict(model=boring_model, datamodule=boring_dm) - assert predictions is not None and len(predictions) != 0, "Prediction not successful" - _assert_state_finished(trainer, "predict") diff --git a/tests/tpu_tests.sh b/tests/tpu_tests.sh deleted file mode 100755 index 17b995d0f0..0000000000 --- a/tests/tpu_tests.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/bin/bash -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -set -e - -# this environment variable allows TPU tests to run -export FLASH_RUN_TPU_TESTS=1 -# python arguments -defaults='-m coverage run --source flash --append -m pytest --durations=0 --capture=no --disable-warnings' - -# TODO: In future, we can use RunIf from PL upstream -grep_output=$(grep --recursive --line-number --word-regexp 'tpu' --regexp 'os.getenv("FLASH_RUN_TPU_TESTS",') -# file paths -files=$(echo "$grep_output" | cut -f1 -d:) -files_arr=($files) -echo $files - -# line numbers -linenos=$(echo "$grep_output" | cut -f2 -d:) -linenos_arr=($linenos) - -# tests to skip - space separated -blocklist='test_pytorch_profiler_nested_emit_nvtx' -report='' - -for i in "${!files_arr[@]}"; do - file=${files_arr[$i]} - lineno=${linenos_arr[$i]} - - # get code from `@RunIf(special=True)` line to EOF - test_code=$(tail -n +"$lineno" "$file") - - # read line by line - while read -r line; do - # if it's a test - if [[ $line == def\ test_* ]]; then - # get the name - test_name=$(echo $line | cut -c 5- | cut -f1 -d\() - - # check blocklist - if echo $blocklist | grep --word-regexp "$test_name" > /dev/null; then - report+="Skipped\t$file:$lineno::$test_name\n" - break - fi - - # SPECIAL_PATTERN allows filtering the tests to run when debugging. - # use as `SPECIAL_PATTERN="foo_bar" ./special_tests.sh` to run only those - # test with `foo_bar` in their name - if [[ $line != *$SPECIAL_PATTERN* ]]; then - report+="Skipped\t$file:$lineno::$test_name\n" - break - fi - - # run the test - report+="Ran\t$file:$lineno::$test_name\n" - python ${defaults} "${file}::${test_name}" - break - fi - done < <(echo "$test_code") -done - -# echo test report -printf '=%.s' {1..80} -printf "\n$report" -printf '=%.s' {1..80} -printf '\n' From c4b9f4bca258b3a1d96fe8ad3561808e1d118586 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 10 May 2023 22:08:30 +0200 Subject: [PATCH 25/39] drop zero_requirements --- zero_requirements/image_classification.txt | 7 ------- zero_requirements/tabular_classification.txt | 3 --- zero_requirements/text_classification.txt | 5 ----- 3 files changed, 15 deletions(-) delete mode 100644 zero_requirements/image_classification.txt delete mode 100644 zero_requirements/tabular_classification.txt delete mode 100644 zero_requirements/text_classification.txt diff --git a/zero_requirements/image_classification.txt b/zero_requirements/image_classification.txt deleted file mode 100644 index 0addd5c433..0000000000 --- a/zero_requirements/image_classification.txt +++ /dev/null @@ -1,7 +0,0 @@ -torchvision -timm>=0.4.5 -lightning-bolts>=0.3.3 -Pillow>=7.2 -kornia>=0.5.1 -pystiche==1.* -segmentation-models-pytorch diff --git a/zero_requirements/tabular_classification.txt b/zero_requirements/tabular_classification.txt deleted file mode 100644 index bbd5720096..0000000000 --- a/zero_requirements/tabular_classification.txt +++ /dev/null @@ -1,3 +0,0 @@ -pytorch-tabnet==3.1 -scikit-learn -pytorch-forecasting diff --git a/zero_requirements/text_classification.txt b/zero_requirements/text_classification.txt deleted file mode 100644 index aba24a7ef5..0000000000 --- a/zero_requirements/text_classification.txt +++ /dev/null @@ -1,5 +0,0 @@ -sentencepiece>=0.1.95 -filelock -transformers>=4.5 -torchmetrics[text]>=0.5.1 -datasets>=1.8,<1.13 From 976160669ce5e5db01049fde9a83d643865af195 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 10 May 2023 23:51:37 +0200 Subject: [PATCH 26/39] fixing cli & reorganize examples (#1549) --- .azure/gpu-special-tests.yml | 72 ++++++++------ .github/workflows/ci-checks.yml | 2 +- .github/workflows/ci-testing.yml | 33 +++---- docs/source/integrations/baal.rst | 2 +- docs/source/integrations/fiftyone.rst | 6 +- docs/source/integrations/learn2learn.rst | 2 +- .../integrations/pytorch_forecasting.rst | 2 +- .../source/reference/audio_classification.rst | 2 +- .../source/reference/graph_classification.rst | 2 +- docs/source/reference/graph_embedder.rst | 2 +- .../source/reference/image_classification.rst | 2 +- .../image_classification_multi_label.rst | 2 +- docs/source/reference/image_embedder.rst | 2 +- .../reference/instance_segmentation.rst | 2 +- docs/source/reference/keypoint_detection.rst | 2 +- docs/source/reference/object_detection.rst | 2 +- .../reference/pointcloud_object_detection.rst | 2 +- .../reference/pointcloud_segmentation.rst | 2 +- docs/source/reference/question_answering.rst | 2 +- .../reference/semantic_segmentation.rst | 2 +- docs/source/reference/speech_recognition.rst | 2 +- docs/source/reference/style_transfer.rst | 2 +- docs/source/reference/summarization.rst | 2 +- .../reference/tabular_classification.rst | 2 +- docs/source/reference/tabular_forecasting.rst | 2 +- docs/source/reference/text_classification.rst | 2 +- .../text_classification_multi_label.rst | 2 +- docs/source/reference/text_embedder.rst | 2 +- docs/source/reference/translation.rst | 2 +- .../source/reference/video_classification.rst | 2 +- examples/{ => audio}/audio_classification.py | 0 examples/{ => audio}/speech_recognition.py | 0 examples/{ => graph}/graph_classification.py | 0 examples/{ => graph}/graph_embedder.py | 0 ...aal_img_classification_active_learning.py} | 0 examples/{ => image}/face_detection.py | 0 .../fiftyone_img_classification.py} | 0 .../fiftyone_img_classification_datasets.py} | 0 .../fiftyone_img_embedding.py} | 0 .../fiftyone_object_detection.py} | 0 examples/{ => image}/image_classification.py | 0 .../image_classification_multi_label.py | 0 examples/{ => image}/image_embedder.py | 0 examples/{ => image}/instance_segmentation.py | 0 examples/{ => image}/keypoint_detection.py | 0 .../labelstudio_img_classification.py} | 0 ...rn2learn_img_classification_imagenette.py} | 0 examples/{ => image}/object_detection.py | 0 examples/{ => image}/semantic_segmentation.py | 0 examples/{ => image}/style_transfer.py | 0 .../pcloud_detection.py} | 0 .../pcloud_segmentation.py} | 0 .../visual_detection.py} | 0 .../visual_segmentation.py} | 0 .../forecasting_interpretable.py} | 0 .../{ => tabular}/tabular_classification.py | 0 examples/{ => tabular}/tabular_forecasting.py | 0 examples/{ => tabular}/tabular_regression.py | 0 .../labelstudio_text_classification.py} | 0 examples/{ => text}/question_answering.py | 0 examples/{ => text}/summarization.py | 0 examples/{ => text}/text_classification.py | 0 .../text_classification_multi_label.py | 0 examples/{ => text}/text_embedder.py | 0 examples/{ => text}/translation.py | 0 .../labelstudio_classification.py} | 0 examples/{ => video}/video_classification.py | 0 requirements.txt | 10 +- src/flash/core/classification.py | 2 +- src/flash/core/utilities/lightning_cli.py | 4 +- tests/core/data/test_callback.py | 1 - tests/core/test_model.py | 5 +- tests/core/utilities/test_lightning_cli.py | 19 ++-- tests/examples/{utils.py => helpers.py} | 0 tests/examples/test_integrations.py | 69 -------------- tests/examples/test_scripts.py | 94 +++++++++++++++---- tests/template/classification/test_model.py | 1 + 77 files changed, 183 insertions(+), 187 deletions(-) rename examples/{ => audio}/audio_classification.py (100%) rename examples/{ => audio}/speech_recognition.py (100%) rename examples/{ => graph}/graph_classification.py (100%) rename examples/{ => graph}/graph_embedder.py (100%) rename examples/{integrations/baal/image_classification_active_learning.py => image/baal_img_classification_active_learning.py} (100%) rename examples/{ => image}/face_detection.py (100%) rename examples/{integrations/fiftyone/image_classification.py => image/fiftyone_img_classification.py} (100%) rename examples/{integrations/fiftyone/image_classification_fiftyone_datasets.py => image/fiftyone_img_classification_datasets.py} (100%) rename examples/{integrations/fiftyone/image_embedding.py => image/fiftyone_img_embedding.py} (100%) rename examples/{integrations/fiftyone/object_detection.py => image/fiftyone_object_detection.py} (100%) rename examples/{ => image}/image_classification.py (100%) rename examples/{ => image}/image_classification_multi_label.py (100%) rename examples/{ => image}/image_embedder.py (100%) rename examples/{ => image}/instance_segmentation.py (100%) rename examples/{ => image}/keypoint_detection.py (100%) rename examples/{integrations/labelstudio/image_classification.py => image/labelstudio_img_classification.py} (100%) rename examples/{integrations/learn2learn/image_classification_imagenette_mini.py => image/learn2learn_img_classification_imagenette.py} (100%) rename examples/{ => image}/object_detection.py (100%) rename examples/{ => image}/semantic_segmentation.py (100%) rename examples/{ => image}/style_transfer.py (100%) rename examples/{pointcloud_detection.py => pointcloud/pcloud_detection.py} (100%) rename examples/{pointcloud_segmentation.py => pointcloud/pcloud_segmentation.py} (100%) rename examples/{visualizations/pointcloud_detection.py => pointcloud/visual_detection.py} (100%) rename examples/{visualizations/pointcloud_segmentation.py => pointcloud/visual_segmentation.py} (100%) rename examples/{integrations/pytorch_forecasting/tabular_forecasting_interpretable.py => tabular/forecasting_interpretable.py} (100%) rename examples/{ => tabular}/tabular_classification.py (100%) rename examples/{ => tabular}/tabular_forecasting.py (100%) rename examples/{ => tabular}/tabular_regression.py (100%) rename examples/{integrations/labelstudio/text_classification.py => text/labelstudio_text_classification.py} (100%) rename examples/{ => text}/question_answering.py (100%) rename examples/{ => text}/summarization.py (100%) rename examples/{ => text}/text_classification.py (100%) rename examples/{ => text}/text_classification_multi_label.py (100%) rename examples/{ => text}/text_embedder.py (100%) rename examples/{ => text}/translation.py (100%) rename examples/{integrations/labelstudio/video_classification.py => video/labelstudio_classification.py} (100%) rename examples/{ => video}/video_classification.py (100%) rename tests/examples/{utils.py => helpers.py} (100%) delete mode 100644 tests/examples/test_integrations.py diff --git a/.azure/gpu-special-tests.yml b/.azure/gpu-special-tests.yml index f7c26da082..02b64bb12c 100644 --- a/.azure/gpu-special-tests.yml +++ b/.azure/gpu-special-tests.yml @@ -25,43 +25,53 @@ jobs: container: # image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.11" image: "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime" - options: "--ipc=host --gpus=all" + options: "--ipc=host --gpus=all -v /usr/bin/docker:/tmp/docker:ro" workspace: clean: all steps: - - bash: echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)" - displayName: 'set visible devices' + - bash: | + echo "##vso[task.setvariable variable=CUDA_VISIBLE_DEVICES]$(DEVICES)" + echo "##vso[task.setvariable variable=CONTAINER_ID]$(head -1 /proc/self/cgroup|cut -d/ -f3)" + displayName: 'Set environment variables' - - bash: | - echo $CUDA_VISIBLE_DEVICES - lspci | egrep 'VGA|3D' - whereis nvidia - nvidia-smi - python --version - pip --version - pip list - df -kh /dev/shm - displayName: 'Image info & NVIDIA' + - script: | + /tmp/docker exec -t -u 0 $CONTAINER_ID \ + sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get -o Dpkg::Options::="--force-confold" -y install sudo" + displayName: 'Install Sudo in container (thanks Microsoft!)' - - bash: | - python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu >= 2, f'GPU: {mgpu}'" - displayName: 'Sanity check' + - bash: | + echo $CUDA_VISIBLE_DEVICES + lspci | egrep 'VGA|3D' + whereis nvidia + nvidia-smi + python --version + pip --version + pip list + df -kh /dev/shm + displayName: 'Image info & NVIDIA' - - bash: | - pip install '.[image,test]' -r requirements/testing_image.txt -U - pip list - env: - FREEZE_REQUIREMENTS: 1 - displayName: 'Install dependencies' + - bash: | + python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu >= 2, f'GPU: {mgpu}'" + displayName: 'Sanity check' - - bash: | - bash tests/special_tests.sh - displayName: 'Testing: special' + - script: | + sudo apt-get install -y build-essential gcc cmake software-properties-common + python -m pip install "pip==22.2.1" + pip --version + pip install '.[image,test]' -r requirements/testing_image.txt -U + pip list + env: + FREEZE_REQUIREMENTS: 1 + displayName: 'Install dependencies' - - bash: | - python -m coverage report - python -m coverage xml - # python -m codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure - ls -l - displayName: 'Statistics' + - bash: | + bash tests/special_tests.sh + displayName: 'Testing: special' + + - bash: | + python -m coverage report + python -m coverage xml + # python -m codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure + ls -l + displayName: 'Statistics' diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index 66add3768c..8c79d3c7fa 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -10,7 +10,7 @@ jobs: check-schema: uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.8.0 with: - azure-dir: '' # ToDo + azure-dir: '.azure' check-package: uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.8.0 diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index da2b686559..d0c848a50a 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -25,28 +25,23 @@ jobs: matrix: # PyTorch 1.5 is failing on Win and bolts requires torchvision>=0.5 os: [ubuntu-20.04, macOS-12, windows-2022] - python-version: [3.7, 3.9] - requires: ['oldest', 'latest'] + python-version: [3.8, 3.9] topic: ['core'] extra: [[]] - exclude: - # Skip if torch<1.8 and py3.9 on Linux: https://github.com/pytorch/pytorch/issues/50014 - - { python-version: 3.9, requires: 'oldest' } - - { os: 'macOS-12', requires: 'oldest' } - - { os: 'windows-2022', requires: 'oldest' } include: - - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'core', extra: []} - - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_extra']} - - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_baal']} - - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_segm']} - - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_vissl']} - - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'video', extra: []} - - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'tabular', extra: []} - - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'text', extra: []} - - { os: 'ubuntu-20.04', python-version: 3.8, topic: 'pointcloud', extra: []} - - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'serve', extra: []} - - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'graph', extra: []} - - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'audio', extra: []} + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'core', extra: [] } + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_extra'] } + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_baal'] } + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_segm'] } + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'image', extra: ['image_vissl'] } + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'video', extra: [] } + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'tabular', extra: [] } + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'text', extra: [] } + - { os: 'ubuntu-20.04', python-version: 3.8, topic: 'pointcloud', extra: [] } + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'serve', extra: [] } + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'graph', extra: [] } + - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'audio', extra: [] } + - { os: 'ubuntu-20.04', python-version: 3.7, topic: 'core', extra: [], requires: 'oldest' } # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 50 diff --git a/docs/source/integrations/baal.rst b/docs/source/integrations/baal.rst index 8088729f84..5477e69a4f 100644 --- a/docs/source/integrations/baal.rst +++ b/docs/source/integrations/baal.rst @@ -29,6 +29,6 @@ The most uncertain samples will be labelled by the human to accelerate the model With its integration within Flash, the Active Learning process is simpler than ever before. -.. literalinclude:: ../../../examples/integrations/baal/image_classification_active_learning.py +.. literalinclude:: ../../../examples/image/baal_img_classification_active_learning.py :language: python :lines: 14- diff --git a/docs/source/integrations/fiftyone.rst b/docs/source/integrations/fiftyone.rst index 2b1fda86d1..7a46166b22 100644 --- a/docs/source/integrations/fiftyone.rst +++ b/docs/source/integrations/fiftyone.rst @@ -57,7 +57,7 @@ dictionaries containing :ref:`FiftyOne Label ` objects and filepaths, which is exactly the output of the FiftyOne outputs when the ``return_filepath=True`` option is specified. -.. literalinclude:: ../../../examples/integrations/fiftyone/image_classification.py +.. literalinclude:: ../../../examples/image/fiftyone_img_classification.py :language: python :lines: 14- @@ -94,7 +94,7 @@ method allows you to load your FiftyOne datasets directly into a :class:`~flash.core.data.data_module.DataModule` to be used for training, testing, or inference. -.. literalinclude:: ../../../examples/integrations/fiftyone/image_classification_fiftyone_datasets.py +.. literalinclude:: ../../../examples/image/fiftyone_img_classification_datasets.py :language: python :lines: 14- @@ -109,7 +109,7 @@ FiftyOne provides the methods for powerful workflows like clustering, similarity search, pre-annotation, and more in only a few lines of code. -.. literalinclude:: ../../../examples/integrations/fiftyone/image_embedding.py +.. literalinclude:: ../../../examples/image/fiftyone_img_embedding.py :language: python :lines: 14- diff --git a/docs/source/integrations/learn2learn.rst b/docs/source/integrations/learn2learn.rst index f16ccffe00..79b2e8ac3b 100644 --- a/docs/source/integrations/learn2learn.rst +++ b/docs/source/integrations/learn2learn.rst @@ -72,7 +72,7 @@ Once done, the users are left to play the hyper-parameters associated with the m Here is an example using `miniImageNet dataset `_ containing 100 classes divided into 64 training, 16 validation, and 20 test classes. -.. literalinclude:: ../../../examples/integrations/learn2learn/image_classification_imagenette_mini.py +.. literalinclude:: ../../../examples/image/learn2learn_img_classification_imagenette.py :language: python :lines: 15- diff --git a/docs/source/integrations/pytorch_forecasting.rst b/docs/source/integrations/pytorch_forecasting.rst index 49776a2dab..f9eed0d9b9 100644 --- a/docs/source/integrations/pytorch_forecasting.rst +++ b/docs/source/integrations/pytorch_forecasting.rst @@ -13,7 +13,7 @@ With these, you can train your model and perform inference using Flash but still Here's an example, plotting the predictions and interpretation analysis from the NBeats model trained in the :ref:`tabular_forecasting` documentation: -.. literalinclude:: ../../../examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py +.. literalinclude:: ../../../examples/tabular/forecasting_interpretable.py :language: python :lines: 14- diff --git a/docs/source/reference/audio_classification.rst b/docs/source/reference/audio_classification.rst index ffc9dfc4d1..8f4bb199df 100644 --- a/docs/source/reference/audio_classification.rst +++ b/docs/source/reference/audio_classification.rst @@ -73,7 +73,7 @@ We then use the trained :class:`~flash.image.classification.model.ImageClassifie Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../examples/audio_classification.py +.. literalinclude:: ../../../examples/audio/audio_classification.py :language: python :lines: 14- diff --git a/docs/source/reference/graph_classification.rst b/docs/source/reference/graph_classification.rst index 92fde7042a..0f5779856f 100644 --- a/docs/source/reference/graph_classification.rst +++ b/docs/source/reference/graph_classification.rst @@ -34,7 +34,7 @@ Next, we use the trained :class:`~flash.graph.classification.model.GraphClassifi Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../examples/graph_classification.py +.. literalinclude:: ../../../examples/graph/graph_classification.py :language: python :lines: 14- diff --git a/docs/source/reference/graph_embedder.rst b/docs/source/reference/graph_embedder.rst index c57323fb8a..2d0dd9f485 100644 --- a/docs/source/reference/graph_embedder.rst +++ b/docs/source/reference/graph_embedder.rst @@ -23,7 +23,7 @@ Next, we load a trained :class:`~flash.graph.classification.model.GraphEmbedder` Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../examples/graph_embedder.py +.. literalinclude:: ../../../examples/graph/graph_embedder.py :language: python :lines: 14 diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index 098f03b9ef..27e38c175c 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -56,7 +56,7 @@ We then use the trained :class:`~flash.image.classification.model.ImageClassifie Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../examples/image_classification.py +.. literalinclude:: ../../../examples/image/image_classification.py :language: python :lines: 14- diff --git a/docs/source/reference/image_classification_multi_label.rst b/docs/source/reference/image_classification_multi_label.rst index fc2b42889f..bf8cc488cb 100644 --- a/docs/source/reference/image_classification_multi_label.rst +++ b/docs/source/reference/image_classification_multi_label.rst @@ -50,7 +50,7 @@ We then use the trained :class:`~flash.image.classification.model.ImageClassifie Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../examples/image_classification_multi_label.py +.. literalinclude:: ../../../examples/image/image_classification_multi_label.py :language: python :lines: 14- diff --git a/docs/source/reference/image_embedder.rst b/docs/source/reference/image_embedder.rst index dc2e66ac20..d84d8b1f13 100644 --- a/docs/source/reference/image_embedder.rst +++ b/docs/source/reference/image_embedder.rst @@ -40,7 +40,7 @@ Next, we configure the :class:`~flash.image.embedding.model.ImageEmbedder` task Finally, we construct a :class:`~flash.core.trainer.Trainer` and call ``fit()``. Here's the full example: -.. literalinclude:: ../../../examples/image_embedder.py +.. literalinclude:: ../../../examples/image/image_embedder.py :language: python :lines: 14- diff --git a/docs/source/reference/instance_segmentation.rst b/docs/source/reference/instance_segmentation.rst index 9d9ba82218..cefedd3d1e 100644 --- a/docs/source/reference/instance_segmentation.rst +++ b/docs/source/reference/instance_segmentation.rst @@ -31,7 +31,7 @@ We then use the trained :class:`~flash.image.instance_segmentation.model.Instanc Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../examples/instance_segmentation.py +.. literalinclude:: ../../../examples/image/instance_segmentation.py :language: python :lines: 14- diff --git a/docs/source/reference/keypoint_detection.rst b/docs/source/reference/keypoint_detection.rst index 692451220a..45b9646b4c 100644 --- a/docs/source/reference/keypoint_detection.rst +++ b/docs/source/reference/keypoint_detection.rst @@ -31,7 +31,7 @@ We then use the trained :class:`~flash.image.keypoint_detection.model.KeypointDe Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../examples/keypoint_detection.py +.. literalinclude:: ../../../examples/image/keypoint_detection.py :language: python :lines: 14- diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index 4aab8e3d5e..96180a5d7f 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -51,7 +51,7 @@ We then use the trained :class:`~flash.image.detection.model.ObjectDetector` for Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../examples/object_detection.py +.. literalinclude:: ../../../examples/image/object_detection.py :language: python :lines: 14- diff --git a/docs/source/reference/pointcloud_object_detection.rst b/docs/source/reference/pointcloud_object_detection.rst index ebb85302b3..bdccf8ffdd 100644 --- a/docs/source/reference/pointcloud_object_detection.rst +++ b/docs/source/reference/pointcloud_object_detection.rst @@ -80,7 +80,7 @@ We then use the trained :class:`~flash.image.detection.model.PointCloudObjectDet Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../examples/pointcloud_detection.py +.. literalinclude:: ../../../examples/pointcloud/pcloud_detection.py :language: python :lines: 14- diff --git a/docs/source/reference/pointcloud_segmentation.rst b/docs/source/reference/pointcloud_segmentation.rst index 6b441f2211..811ed6d31d 100644 --- a/docs/source/reference/pointcloud_segmentation.rst +++ b/docs/source/reference/pointcloud_segmentation.rst @@ -71,7 +71,7 @@ We then use the trained ``PointCloudSegmentation`` for inference. Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../examples/pointcloud_segmentation.py +.. literalinclude:: ../../../examples/pointcloud/pcloud_segmentation.py :language: python :lines: 14- diff --git a/docs/source/reference/question_answering.rst b/docs/source/reference/question_answering.rst index 29b710181b..e62ee4543c 100644 --- a/docs/source/reference/question_answering.rst +++ b/docs/source/reference/question_answering.rst @@ -60,7 +60,7 @@ Next, we use the trained :class:`~flash.text.question_answering.model.QuestionAn Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../examples/question_answering.py +.. literalinclude:: ../../../examples/text/question_answering.py :language: python :lines: 14- diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst index 1f5738629c..9a880d6d80 100644 --- a/docs/source/reference/semantic_segmentation.rst +++ b/docs/source/reference/semantic_segmentation.rst @@ -45,7 +45,7 @@ We then use the trained :class:`~flash.image.segmentation.model.SemanticSegmenta Finally, we save the model. Here's the full example: -.. literalinclude:: ../../../examples/semantic_segmentation.py +.. literalinclude:: ../../../examples/image/semantic_segmentation.py :language: python :lines: 14- diff --git a/docs/source/reference/speech_recognition.rst b/docs/source/reference/speech_recognition.rst index f9a8753f48..b61ecd0e87 100644 --- a/docs/source/reference/speech_recognition.rst +++ b/docs/source/reference/speech_recognition.rst @@ -49,7 +49,7 @@ The backbone can be any Wav2Vec model from `HuggingFace transformers =1.7.1 -torchmetrics >=0.7.0, <0.11.0 # strict -pytorch-lightning >=1.6.0, <1.9.0 # strict -pyDeprecate -pandas >=1.1.0, <=1.5.2 +torch >1.7.0 +torchmetrics >0.7.0, <0.11.0 # strict +pytorch-lightning >1.6.0, <1.9.0 # strict +pyDeprecate >0.1.0 +pandas >1.1.0, <=1.5.2 jsonargparse[signatures] >=3.17.0, <=4.9.0 click >=7.1.2, <=8.1.3 protobuf <=3.20.1 diff --git a/src/flash/core/classification.py b/src/flash/core/classification.py index 25973d79f1..f8c86fe844 100644 --- a/src/flash/core/classification.py +++ b/src/flash/core/classification.py @@ -60,7 +60,7 @@ def _build( if metrics is None: metrics = ( - F1Score(num_classes=num_classes, task="multilabel", top_k=1) + F1Score(num_labels=num_classes, task="multilabel", top_k=1) if (multi_label and num_classes) else Accuracy() ) diff --git a/src/flash/core/utilities/lightning_cli.py b/src/flash/core/utilities/lightning_cli.py index 864e047642..37ce4a470e 100644 --- a/src/flash/core/utilities/lightning_cli.py +++ b/src/flash/core/utilities/lightning_cli.py @@ -12,10 +12,8 @@ from jsonargparse import ActionConfigFile, ArgumentParser, set_config_read_mode from jsonargparse.signatures import ClassFromFunctionBase from jsonargparse.typehints import ClassType +from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.callbacks import Callback -from pytorch_lightning.core.datamodule import LightningDataModule -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import seed_everything diff --git a/tests/core/data/test_callback.py b/tests/core/data/test_callback.py index 73d8803436..cac8ec6bd8 100644 --- a/tests/core/data/test_callback.py +++ b/tests/core/data/test_callback.py @@ -75,7 +75,6 @@ def test_step(self, batch, batch_idx): max_epochs=1, limit_val_batches=1, limit_train_batches=1, - progress_bar_refresh_rate=0, ) transform = InputTransform() dm = DataModule( diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 8cb1fd5349..1a136fe289 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -235,7 +235,10 @@ def test_classification_task_trainer_predict(tmpdir): pytest.param( TabularClassifier, "0.7.0/tabular_classification_model.pt", - marks=pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular packages aren't installed"), + marks=[ + pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular packages aren't installed"), + pytest.mark.xfail(RuntimeError, reason="upgraded Tabular to 1.0"), + ], ), pytest.param( TextClassifier, diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py index e40d3e946f..a337125bb4 100644 --- a/tests/core/utilities/test_lightning_cli.py +++ b/tests/core/utilities/test_lightning_cli.py @@ -133,6 +133,7 @@ def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): ("--gpus 0,1", [0, 1]), ], ) +@pytest.mark.xfail(strict=False, reason="mocking does not work as expected") # fixme def test_parse_args_parsing_gpus(mocker, cli_args, expected_gpu): """Test parsing of gpus and instantiation of Trainer.""" mocker.patch("lightning_lite.utilities.device_parser._get_all_available_gpus", return_value=[0, 1]) @@ -297,7 +298,7 @@ def test_lightning_cli_args(tmpdir): f"--data.data_dir={tmpdir}", f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", - "--trainer.weights_summary=null", + "--trainer.enable_model_summary=false", "--seed_everything=1234", ] @@ -344,7 +345,7 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir): config = dict( model=dict(class_path="tests.helpers.boring_model.BoringModel"), data=dict(class_path="tests.helpers.boring_model.BoringDataModule", init_args=dict(data_dir=str(tmpdir))), - trainer=dict(default_root_dir=str(tmpdir), max_epochs=1, weights_summary=None), + trainer=dict(default_root_dir=str(tmpdir), max_epochs=1, enable_model_summary=False), ) config_path = tmpdir / "config.yaml" with open(config_path, "w") as f: @@ -643,7 +644,7 @@ def add_arguments_to_parser(self, parser): assert cli.model.configure_optimizers is not BoringModel.configure_optimizers assert len(cli.trainer.optimizers) == 1 assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) - assert len(cli.trainer.lr_schedulers) == 0 + assert len(cli.trainer.lr_scheduler_configs) == 0 @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @@ -665,9 +666,9 @@ def add_arguments_to_parser(self, parser): assert cli.model.configure_optimizers is not BoringModel.configure_optimizers assert len(cli.trainer.optimizers) == 1 assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) - assert len(cli.trainer.lr_schedulers) == 1 - assert isinstance(cli.trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.ExponentialLR) - assert cli.trainer.lr_schedulers[0]["scheduler"].gamma == 0.8 + assert len(cli.trainer.lr_scheduler_configs) == 1 + assert isinstance(cli.trainer.lr_scheduler_configs[0].scheduler, torch.optim.lr_scheduler.ExponentialLR) + assert cli.trainer.lr_scheduler_configs[0].scheduler.gamma == 0.8 @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @@ -697,9 +698,9 @@ def add_arguments_to_parser(self, parser): assert len(cli.trainer.optimizers) == 1 assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam) - assert len(cli.trainer.lr_schedulers) == 1 - assert isinstance(cli.trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.StepLR) - assert cli.trainer.lr_schedulers[0]["scheduler"].step_size == 50 + assert len(cli.trainer.lr_scheduler_configs) == 1 + assert isinstance(cli.trainer.lr_scheduler_configs[0].scheduler, torch.optim.lr_scheduler.StepLR) + assert cli.trainer.lr_scheduler_configs[0].scheduler.step_size == 50 class TestModelOptLR(BoringModel): diff --git a/tests/examples/utils.py b/tests/examples/helpers.py similarity index 100% rename from tests/examples/utils.py rename to tests/examples/helpers.py diff --git a/tests/examples/test_integrations.py b/tests/examples/test_integrations.py deleted file mode 100644 index 8b4f521619..0000000000 --- a/tests/examples/test_integrations.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from pathlib import Path -from unittest import mock - -import pytest - -from flash.core.utilities.imports import ( - _BAAL_AVAILABLE, - _FIFTYONE_AVAILABLE, - _LEARN2LEARN_AVAILABLE, - _TOPIC_IMAGE_AVAILABLE, -) -from tests.examples.utils import run_test - -root = Path(__file__).parent.parent.parent - - -@mock.patch.dict(os.environ, {"FLASH_TESTING": "1"}) -@pytest.mark.parametrize( - "folder, file", - [ - pytest.param( - "fiftyone", - "image_classification.py", - marks=pytest.mark.skipif( - not (_TOPIC_IMAGE_AVAILABLE and _FIFTYONE_AVAILABLE), reason="fiftyone library isn't installed" - ), - ), - pytest.param( - "fiftyone", - "object_detection.py", - marks=pytest.mark.skipif( - not (_TOPIC_IMAGE_AVAILABLE and _FIFTYONE_AVAILABLE), reason="fiftyone library isn't installed" - ), - ), - pytest.param( - "baal", - "image_classification_active_learning.py", - marks=pytest.mark.skipif( - not (_TOPIC_IMAGE_AVAILABLE and _BAAL_AVAILABLE), reason="baal library isn't installed" - ), - ), - pytest.param( - "learn2learn", - "image_classification_imagenette_mini.py", - marks=[ - pytest.mark.skip("MiniImagenet broken: https://github.com/learnables/learn2learn/issues/291"), - pytest.mark.skipif( - not (_TOPIC_IMAGE_AVAILABLE and _LEARN2LEARN_AVAILABLE), reason="learn2learn isn't installed" - ), - ], - ), - ], -) -def test_integrations(tmpdir, folder, file): - run_test(str(root / "examples" / "integrations" / folder / file)) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 2f8202b252..e37d1a039b 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -20,8 +20,11 @@ import torch from flash.core.utilities.imports import ( + _BAAL_AVAILABLE, + _FIFTYONE_AVAILABLE, _ICEDATA_AVAILABLE, _ICEVISION_AVAILABLE, + _LEARN2LEARN_AVAILABLE, _SEGMENTATION_MODELS_AVAILABLE, _TOPIC_AUDIO_AVAILABLE, _TOPIC_CORE_AVAILABLE, @@ -34,7 +37,7 @@ _TORCHVISION_GREATER_EQUAL_0_9, _VISSL_AVAILABLE, ) -from tests.examples.utils import run_test +from tests.examples.helpers import run_test from tests.helpers.decorators import forked root = Path(__file__).parent.parent.parent @@ -42,25 +45,39 @@ @mock.patch.dict(os.environ, {"FLASH_TESTING": "1"}) @pytest.mark.parametrize( - "file", + "folder,fname", [ pytest.param( + "", + "template.py", + marks=[ + pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core."), + pytest.mark.skipif(os.name == "posix", reason="Flaky on Mac OS (CI)"), + pytest.mark.skipif(sys.version_info >= (3, 9), reason="Undiagnosed segmentation fault in 3.9"), + ], + ), + pytest.param( + "audio", "audio_classification.py", marks=pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed"), ), pytest.param( + "audio", "speech_recognition.py", marks=pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed"), ), pytest.param( + "image", "image_classification.py", marks=pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), ), pytest.param( + "image", "image_classification_multi_label.py", marks=pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), ), pytest.param( + "image", "image_embedder.py", marks=[ pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), @@ -69,12 +86,14 @@ ], ), pytest.param( + "image", "object_detection.py", marks=pytest.mark.skipif( not (_TOPIC_IMAGE_AVAILABLE and _ICEVISION_AVAILABLE), reason="image libraries aren't installed" ), ), pytest.param( + "image", "instance_segmentation.py", marks=[ pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), @@ -83,6 +102,7 @@ ], ), pytest.param( + "image", "keypoint_detection.py", marks=[ pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), @@ -90,10 +110,7 @@ ], ), pytest.param( - "question_answering.py", - marks=pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed"), - ), - pytest.param( + "image", "semantic_segmentation.py", marks=[ pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), @@ -102,6 +119,7 @@ ], ), pytest.param( + "image", "style_transfer.py", marks=[ pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"), @@ -109,34 +127,37 @@ ], ), pytest.param( + "text", + "question_answering.py", + marks=pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed"), + ), + pytest.param( + "text", "summarization.py", marks=pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed"), ), pytest.param( + "tabular", "tabular_classification.py", marks=pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed"), ), pytest.param( + "tabular", "tabular_regression.py", marks=pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed"), ), pytest.param( + "tabular", "tabular_forecasting.py", marks=pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed"), ), pytest.param( - "template.py", - marks=[ - pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core."), - pytest.mark.skipif(os.name == "posix", reason="Flaky on Mac OS (CI)"), - pytest.mark.skipif(sys.version_info >= (3, 9), reason="Undiagnosed segmentation fault in 3.9"), - ], - ), - pytest.param( + "text", "text_classification.py", marks=pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed"), ), pytest.param( + "text", "text_embedder.py", marks=pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed"), ), @@ -145,6 +166,7 @@ # marks=pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed") # ), pytest.param( + "text", "translation.py", marks=[ pytest.mark.skipif(not _TOPIC_TEXT_AVAILABLE, reason="text libraries aren't installed"), @@ -152,28 +174,64 @@ ], ), pytest.param( + "video", "video_classification.py", marks=pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="video libraries aren't installed"), ), pytest.param( - "pointcloud_segmentation.py", + "pointcloud", + "pcloud_segmentation.py", marks=pytest.mark.skipif(not _TOPIC_POINTCLOUD_AVAILABLE, reason="pointcloud libraries aren't installed"), ), pytest.param( - "pointcloud_detection.py", + "pointcloud", + "pcloud_detection.py", marks=pytest.mark.skipif(not _TOPIC_POINTCLOUD_AVAILABLE, reason="pointcloud libraries aren't installed"), ), pytest.param( + "graph", "graph_classification.py", marks=pytest.mark.skipif(not _TOPIC_GRAPH_AVAILABLE, reason="graph libraries aren't installed"), ), pytest.param( + "graph", "graph_embedder.py", marks=pytest.mark.skipif(not _TOPIC_GRAPH_AVAILABLE, reason="graph libraries aren't installed"), ), + pytest.param( + "image", + "fiftyone_img_classification.py", + marks=pytest.mark.skipif( + not (_TOPIC_IMAGE_AVAILABLE and _FIFTYONE_AVAILABLE), reason="fiftyone library isn't installed" + ), + ), + pytest.param( + "image", + "fiftyone_object_detection.py", + marks=pytest.mark.skipif( + not (_TOPIC_IMAGE_AVAILABLE and _FIFTYONE_AVAILABLE), reason="fiftyone library isn't installed" + ), + ), + pytest.param( + "image", + "baal_img_classification_active_learning.py", + marks=pytest.mark.skipif( + not (_TOPIC_IMAGE_AVAILABLE and _BAAL_AVAILABLE), reason="baal library isn't installed" + ), + ), + pytest.param( + "image", + "learn2learn_img_classification_imagenette.py", + marks=[ + pytest.mark.skip("MiniImagenet broken: https://github.com/learnables/learn2learn/issues/291"), + pytest.mark.skipif( + not (_TOPIC_IMAGE_AVAILABLE and _LEARN2LEARN_AVAILABLE), reason="learn2learn isn't installed" + ), + ], + ), ], ) @forked @pytest.mark.skipif(sys.platform == "darwin", reason="Fatal Python error: Illegal instruction") # fixme -def test_example(tmpdir, file): - run_test(str(root / "examples" / file)) +def test_example(folder, fname): + run_test(str(root / "examples" / folder / fname)) diff --git a/tests/template/classification/test_model.py b/tests/template/classification/test_model.py index 41f6d45e09..ca9b7d455a 100644 --- a/tests/template/classification/test_model.py +++ b/tests/template/classification/test_model.py @@ -124,6 +124,7 @@ def test_predict_sklearn(): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 16),))]) +@pytest.mark.xfail(RuntimeError, reason="TemplateSKLearnClassifier is not attached to a `Trainer`") # fixme def test_jit(tmpdir, jitter, args): path = os.path.join(tmpdir, "testing_model.pt") From 5c42e23fdf2787f9adad2f6f5d0c098d116d8f61 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 11 May 2023 00:47:59 +0200 Subject: [PATCH 27/39] ci: azure schema (#1550) --- .azure/template-examples.yml | 9 --------- .github/workflows/ci-checks.yml | 8 +++++++- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/.azure/template-examples.yml b/.azure/template-examples.yml index ce88118ba6..8b781c0cd8 100644 --- a/.azure/template-examples.yml +++ b/.azure/template-examples.yml @@ -66,12 +66,3 @@ jobs: # python -m codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure ls -l displayName: 'Statistics' - - - task: PublishCodeCoverageResults@1 - displayName: 'Publish coverage report' - inputs: - codeCoverageTool: 'cobertura' - summaryFileLocation: 'coverage.xml' - reportDirectory: '$(Build.SourcesDirectory)/htmlcov' - testRunTitle: '$(Agent.OS) - $(Build.BuildNumber)[$(Agent.JobName)] - Python $(python.version)' - condition: succeededOrFailed() diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index 8c79d3c7fa..86e10b29b4 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -10,7 +10,8 @@ jobs: check-schema: uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.8.0 with: - azure-dir: '.azure' + # todo: validation has some problem with `- ${{ each topic in parameters.domains }}:` construct + azure-dir: "" check-package: uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.8.0 @@ -18,3 +19,8 @@ jobs: actions-ref: v0.8.0 artifact-name: dist-packages-${{ github.sha }} import-name: "flash" + testing-matrix: | + { + "os": ["ubuntu-20.04", "macos-11", "windows-2022"], + "python-version": ["3.8"] + } From d3421c0706d9b0aa1801af2e788045315c1b68d5 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 11 May 2023 14:51:29 +0200 Subject: [PATCH 28/39] resolving tabular task & test serve (#1551) --- .github/workflows/ci-testing.yml | 8 ++- requirements.txt | 2 +- requirements/test.txt | 11 ++-- tests/audio/classification/test_model.py | 4 +- tests/audio/speech_recognition/test_model.py | 4 +- tests/core/data/test_callback.py | 54 ++++++++--------- tests/core/data/test_data_module.py | 6 +- tests/core/test_model.py | 7 +-- tests/core/utilities/test_lightning_cli.py | 60 +++++++++---------- tests/examples/test_scripts.py | 4 +- tests/helpers/task_tester.py | 4 +- tests/image/classification/test_model.py | 4 +- tests/image/detection/test_model.py | 4 +- tests/image/face_detection/test_model.py | 4 +- tests/image/semantic_segm/test_model.py | 4 +- tests/serve/.gitkeep | 0 tests/serve/__init__.py | 1 - .../test_data_model_integration.py | 3 +- tests/tabular/classification/test_model.py | 10 ++-- .../regression/test_data_model_integration.py | 9 ++- tests/tabular/regression/test_model.py | 10 ++-- tests/text/classification/test_model.py | 4 +- .../text/seq2seq/summarization/test_model.py | 4 +- tests/text/seq2seq/translation/test_model.py | 4 +- 24 files changed, 114 insertions(+), 111 deletions(-) create mode 100644 tests/serve/.gitkeep delete mode 100644 tests/serve/__init__.py diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index d0c848a50a..80e6b2dee4 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -108,11 +108,13 @@ jobs: key: flash-datasets-${{ hashFiles('tests/examples/test_scripts.py') }} restore-keys: flash-datasets- + # ToDO + #- name: DocTests + # run: | + # pytest src/ -vv # --reruns 3 --reruns-delay 2 + - name: Tests - env: - FIFTYONE_DO_NOT_TRACK: true run: | - # FixMe: include doctests for src/ coverage run --source flash -m pytest \ tests/core \ tests/deprecated_api \ diff --git a/requirements.txt b/requirements.txt index 48cf9512c7..d699c72e91 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ torchmetrics >0.7.0, <0.11.0 # strict pytorch-lightning >1.6.0, <1.9.0 # strict pyDeprecate >0.1.0 pandas >1.1.0, <=1.5.2 -jsonargparse[signatures] >=3.17.0, <=4.9.0 +jsonargparse[signatures] >4.0.0, <=4.9.0 click >=7.1.2, <=8.1.3 protobuf <=3.20.1 fsspec[http] >=2022.5.0,<=2022.7.1 diff --git a/requirements/test.txt b/requirements/test.txt index bc1e3fb2a2..d5bca03782 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,12 +1,11 @@ # NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup coverage[toml] -codecov >2.1 -pytest >7.2, <7.4 -pytest-doctestplus >0.9.0 -pytest-rerunfailures >10.0 -pytest-forked -pytest-mock +pytest >6.2, <7.0 +pytest-doctestplus >0.12.0 +pytest-rerunfailures >11.0.0 +pytest-forked ==1.6.0 +pytest-mock ==3.10.0 scikit-learn torch_optimizer diff --git a/tests/audio/classification/test_model.py b/tests/audio/classification/test_model.py index 9bd2c83782..89bea0fb97 100644 --- a/tests/audio/classification/test_model.py +++ b/tests/audio/classification/test_model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest import mock +from unittest.mock import patch import pytest @@ -23,7 +23,7 @@ @pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") def test_cli(): cli_args = ["flash", "audio_classification", "--trainer.fast_dev_run", "True"] - with mock.patch("sys.argv", cli_args): + with patch("sys.argv", cli_args): try: main() except SystemExit: diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py index 1271a0acfa..4751c7ad8e 100644 --- a/tests/audio/speech_recognition/test_model.py +++ b/tests/audio/speech_recognition/test_model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License.import os from typing import Any -from unittest import mock +from unittest.mock import patch import numpy as np import pytest @@ -68,7 +68,7 @@ def test_modules_to_freeze(): @pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") -@mock.patch("flash._IS_TESTING", True) +@patch("flash._IS_TESTING", True) def test_serve(): model = SpeechRecognition(backbone=TEST_BACKBONE) model.eval() diff --git a/tests/core/data/test_callback.py b/tests/core/data/test_callback.py index cac8ec6bd8..7794298d64 100644 --- a/tests/core/data/test_callback.py +++ b/tests/core/data/test_callback.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest import mock +from unittest.mock import ANY, MagicMock, call, patch import pytest import torch @@ -26,12 +26,12 @@ @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") -@mock.patch("pickle.dumps") # need to mock pickle or we get pickle error -@mock.patch("torch.save") # need to mock torch.save, or we get pickle error +@patch("pickle.dumps") # need to mock pickle or we get pickle error +@patch("torch.save") # need to mock torch.save, or we get pickle error def test_flash_callback(_, __, tmpdir): """Test the callback hook system for fit.""" - callback_mock = mock.MagicMock() + callback_mock = MagicMock() inputs = [(torch.rand(1), torch.rand(1))] transform = InputTransform() @@ -48,10 +48,10 @@ def test_flash_callback(_, __, tmpdir): _ = next(iter(dm.train_dataloader())) assert callback_mock.method_calls == [ - mock.call.on_load_sample(mock.ANY, RunningStage.TRAINING), - mock.call.on_per_sample_transform(mock.ANY, RunningStage.TRAINING), - mock.call.on_collate(mock.ANY, RunningStage.TRAINING), - mock.call.on_per_batch_transform(mock.ANY, RunningStage.TRAINING), + call.on_load_sample(ANY, RunningStage.TRAINING), + call.on_per_sample_transform(ANY, RunningStage.TRAINING), + call.on_collate(ANY, RunningStage.TRAINING), + call.on_per_batch_transform(ANY, RunningStage.TRAINING), ] class CustomModel(Task): @@ -89,23 +89,23 @@ def test_step(self, batch, batch_idx): trainer.fit(CustomModel(), datamodule=dm) assert callback_mock.method_calls == [ - mock.call.on_load_sample(mock.ANY, RunningStage.TRAINING), - mock.call.on_per_sample_transform(mock.ANY, RunningStage.TRAINING), - mock.call.on_collate(mock.ANY, RunningStage.TRAINING), - mock.call.on_per_batch_transform(mock.ANY, RunningStage.TRAINING), - mock.call.on_load_sample(mock.ANY, RunningStage.VALIDATING), - mock.call.on_per_sample_transform(mock.ANY, RunningStage.VALIDATING), - mock.call.on_collate(mock.ANY, RunningStage.VALIDATING), - mock.call.on_per_batch_transform(mock.ANY, RunningStage.VALIDATING), - mock.call.on_per_batch_transform_on_device(mock.ANY, RunningStage.VALIDATING), - mock.call.on_load_sample(mock.ANY, RunningStage.TRAINING), - mock.call.on_per_sample_transform(mock.ANY, RunningStage.TRAINING), - mock.call.on_collate(mock.ANY, RunningStage.TRAINING), - mock.call.on_per_batch_transform(mock.ANY, RunningStage.TRAINING), - mock.call.on_per_batch_transform_on_device(mock.ANY, RunningStage.TRAINING), - mock.call.on_load_sample(mock.ANY, RunningStage.VALIDATING), - mock.call.on_per_sample_transform(mock.ANY, RunningStage.VALIDATING), - mock.call.on_collate(mock.ANY, RunningStage.VALIDATING), - mock.call.on_per_batch_transform(mock.ANY, RunningStage.VALIDATING), - mock.call.on_per_batch_transform_on_device(mock.ANY, RunningStage.VALIDATING), + call.on_load_sample(ANY, RunningStage.TRAINING), + call.on_per_sample_transform(ANY, RunningStage.TRAINING), + call.on_collate(ANY, RunningStage.TRAINING), + call.on_per_batch_transform(ANY, RunningStage.TRAINING), + call.on_load_sample(ANY, RunningStage.VALIDATING), + call.on_per_sample_transform(ANY, RunningStage.VALIDATING), + call.on_collate(ANY, RunningStage.VALIDATING), + call.on_per_batch_transform(ANY, RunningStage.VALIDATING), + call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING), + call.on_load_sample(ANY, RunningStage.TRAINING), + call.on_per_sample_transform(ANY, RunningStage.TRAINING), + call.on_collate(ANY, RunningStage.TRAINING), + call.on_per_batch_transform(ANY, RunningStage.TRAINING), + call.on_per_batch_transform_on_device(ANY, RunningStage.TRAINING), + call.on_load_sample(ANY, RunningStage.VALIDATING), + call.on_per_sample_transform(ANY, RunningStage.VALIDATING), + call.on_collate(ANY, RunningStage.VALIDATING), + call.on_per_batch_transform(ANY, RunningStage.VALIDATING), + call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING), ] diff --git a/tests/core/data/test_data_module.py b/tests/core/data/test_data_module.py index 46f812bceb..5614e32cf6 100644 --- a/tests/core/data/test_data_module.py +++ b/tests/core/data/test_data_module.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass from typing import Callable, Dict -from unittest import mock +from unittest.mock import MagicMock, NonCallableMock, patch import numpy as np import pytest @@ -426,8 +426,8 @@ def validation_step(self, batch, batch_idx): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") -@pytest.mark.parametrize("sampler, callable", [(mock.MagicMock(), True), (mock.NonCallableMock(), False)]) -@mock.patch("flash.core.data.data_module.DataLoader") +@pytest.mark.parametrize("sampler, callable", [(MagicMock(), True), (NonCallableMock(), False)]) +@patch("flash.core.data.data_module.DataLoader") def test_dataloaders_with_sampler(mock_dataloader, sampler, callable): train_input = TestInput(RunningStage.TRAINING, [1]) datamodule = DataModule( diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 1a136fe289..110d4f87c3 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -17,8 +17,7 @@ from itertools import chain from numbers import Number from typing import Any, Tuple -from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import ANY, MagicMock import pytest import pytorch_lightning as pl @@ -387,12 +386,12 @@ def test_optimizer_learning_rate(): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) ClassificationTask(model, optimizer="test").configure_optimizers() - mock_optimizer.assert_called_once_with(mock.ANY) + mock_optimizer.assert_called_once_with(ANY) mock_optimizer.reset_mock() ClassificationTask(model, optimizer="test", learning_rate=10).configure_optimizers() - mock_optimizer.assert_called_once_with(mock.ANY, lr=10) + mock_optimizer.assert_called_once_with(ANY, lr=10) mock_optimizer.reset_mock() diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py index a337125bb4..5602c32621 100644 --- a/tests/core/utilities/test_lightning_cli.py +++ b/tests/core/utilities/test_lightning_cli.py @@ -9,7 +9,7 @@ from contextlib import redirect_stdout from io import StringIO from typing import List, Optional, Union -from unittest import mock +from unittest.mock import patch import pytest import torch @@ -36,7 +36,7 @@ @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") -@mock.patch("argparse.ArgumentParser.parse_args") +@patch("argparse.ArgumentParser.parse_args") def test_default_args(mock_argparse, tmpdir): """Tests default argument parser for Trainer.""" mock_argparse.return_value = Namespace(**Trainer.default_attributes()) @@ -94,7 +94,7 @@ def test_parse_args_parsing(cli_args, expected): cli_args = cli_args.split(" ") if cli_args else [] parser = LightningArgumentParser(add_help=False, parse_as_dict=False) parser.add_lightning_class_args(Trainer, None) - with mock.patch("sys.argv", ["any.py"] + cli_args): + with patch("sys.argv", ["any.py"] + cli_args): args = parser.parse_args() for k, v in expected.items(): @@ -115,7 +115,7 @@ def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): """Test parsing complex types.""" parser = LightningArgumentParser(add_help=False, parse_as_dict=False) parser.add_lightning_class_args(Trainer, None) - with mock.patch("sys.argv", ["any.py"] + cli_args): + with patch("sys.argv", ["any.py"] + cli_args): args = parser.parse_args() for k, v in expected.items(): @@ -140,7 +140,7 @@ def test_parse_args_parsing_gpus(mocker, cli_args, expected_gpu): cli_args = cli_args.split(" ") if cli_args else [] parser = LightningArgumentParser(add_help=False, parse_as_dict=False) parser.add_lightning_class_args(Trainer, None) - with mock.patch("sys.argv", ["any.py"] + cli_args): + with patch("sys.argv", ["any.py"] + cli_args): args = parser.parse_args() trainer = Trainer.from_argparse_args(args) @@ -165,7 +165,7 @@ def test_init_from_argparse_args(cli_args, extra_args): unknown_args = dict(unknown_arg=0) # unkown args in the argparser/namespace should be ignored - with mock.patch("pytorch_lightning.Trainer.__init__", autospec=True, return_value=None) as init: + with patch("pytorch_lightning.Trainer.__init__", autospec=True, return_value=None) as init: trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args) expected = dict(cli_args) expected.update(extra_args) # extra args should override any cli arg @@ -220,7 +220,7 @@ def on_train_start(callback, trainer, _): monkeypatch.setattr(Trainer, "fit", fit) monkeypatch.setattr(SaveConfigCallback, "on_train_start", on_train_start) - with mock.patch("sys.argv", ["any.py", "--model.model_param=7", "--trainer.limit_train_batches=100"]): + with patch("sys.argv", ["any.py", "--model.model_param=7", "--trainer.limit_train_batches=100"]): cli = LightningCLI(model_class, trainer_class=trainer_class, save_config_callback=SaveConfigCallback) assert hasattr(cli.trainer, "ran_asserts") and cli.trainer.ran_asserts @@ -248,7 +248,7 @@ def test_lightning_cli_args_callbacks(tmpdir): dict(class_path="pytorch_lightning.callbacks.ModelCheckpoint", init_args=dict(monitor="NAME")), ] - with mock.patch("sys.argv", ["any.py", f"--trainer.callbacks={json.dumps(callbacks)}"]): + with patch("sys.argv", ["any.py", f"--trainer.callbacks={json.dumps(callbacks)}"]): cli = LightningCLI(TestModelCallbacks, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) assert cli.trainer.ran_asserts @@ -266,7 +266,7 @@ def add_arguments_to_parser(self, parser): "--learning_rate_monitor.logging_interval=epoch", ] - with mock.patch("sys.argv", ["any.py"] + cli_args): + with patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI(BoringModel) callback = [c for c in cli.trainer.callbacks if isinstance(c, LearningRateMonitor)] @@ -286,7 +286,7 @@ def on_fit_start(self): def test_lightning_cli_args_cluster_environments(tmpdir): plugins = [dict(class_path="pytorch_lightning.plugins.environments.SLURMEnvironment")] - with mock.patch("sys.argv", ["any.py", f"--trainer.plugins={json.dumps(plugins)}"]): + with patch("sys.argv", ["any.py", f"--trainer.plugins={json.dumps(plugins)}"]): cli = LightningCLI(TestModelClusterEnv, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) assert cli.trainer.ran_asserts @@ -302,7 +302,7 @@ def test_lightning_cli_args(tmpdir): "--seed_everything=1234", ] - with mock.patch("sys.argv", ["any.py"] + cli_args): + with patch("sys.argv", ["any.py"] + cli_args): cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={"callbacks": [LearningRateMonitor()]}) assert cli.config["seed_everything"] == 1234 @@ -325,18 +325,18 @@ def test_lightning_cli_save_config_cases(tmpdir): ] # With fast_dev_run!=False config should not be saved - with mock.patch("sys.argv", ["any.py"] + cli_args): + with patch("sys.argv", ["any.py"] + cli_args): LightningCLI(BoringModel) assert not os.path.isfile(config_path) # With fast_dev_run==False config should be saved cli_args[-1] = "--trainer.max_epochs=1" - with mock.patch("sys.argv", ["any.py"] + cli_args): + with patch("sys.argv", ["any.py"] + cli_args): LightningCLI(BoringModel) assert os.path.isfile(config_path) # If run again on same directory exception should be raised since config file already exists - with mock.patch("sys.argv", ["any.py"] + cli_args), pytest.raises(RuntimeError): + with patch("sys.argv", ["any.py"] + cli_args), pytest.raises(RuntimeError): LightningCLI(BoringModel) @@ -351,7 +351,7 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir): with open(config_path, "w") as f: f.write(yaml.dump(config)) - with mock.patch("sys.argv", ["any.py", "--config", str(config_path)]): + with patch("sys.argv", ["any.py", "--config", str(config_path)]): cli = LightningCLI( BoringModel, BoringDataModule, @@ -382,7 +382,7 @@ def any_model_any_data_cli(): def test_lightning_cli_help(): cli_args = ["any.py", "--help"] out = StringIO() - with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit): + with patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit): any_model_any_data_cli() assert "--print_config" in out.getvalue() @@ -398,7 +398,7 @@ def test_lightning_cli_help(): cli_args = ["any.py", "--data.help=tests.helpers.boring_model.BoringDataModule"] out = StringIO() - with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit): + with patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit): any_model_any_data_cli() assert "--data.init_args.data_dir" in out.getvalue() @@ -415,7 +415,7 @@ def test_lightning_cli_print_config(): ] out = StringIO() - with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit): + with patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit): any_model_any_data_cli() outval = yaml.safe_load(out.getvalue()) @@ -455,7 +455,7 @@ def test_lightning_cli_submodules(tmpdir): f"--config={str(config_path)}", ] - with mock.patch("sys.argv", ["any.py"] + cli_args): + with patch("sys.argv", ["any.py"] + cli_args): cli = LightningCLI(MainModule) assert cli.config["model"]["main_param"] == 2 @@ -500,7 +500,7 @@ def test_lightning_cli_torch_modules(tmpdir): f"--config={str(config_path)}", ] - with mock.patch("sys.argv", ["any.py"] + cli_args): + with patch("sys.argv", ["any.py"] + cli_args): cli = LightningCLI(TestModuleTorch) assert isinstance(cli.model.activation, torch.nn.LeakyReLU) @@ -543,7 +543,7 @@ def add_arguments_to_parser(self, parser): "--data.batch_size=12", ] - with mock.patch("sys.argv", ["any.py"] + cli_args): + with patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses) assert cli.model.batch_size == 12 @@ -556,7 +556,7 @@ def add_arguments_to_parser(self, parser): cli_args[-1] = "--model=tests.core.utilities.test_lightning_cli.BoringModelRequiredClasses" - with mock.patch("sys.argv", ["any.py"] + cli_args): + with patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI( BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, @@ -590,7 +590,7 @@ def on_exception(self, execption): ) @pytest.mark.xfail(reason="Bugs in PL >= 1.6.0") def test_cli_ddp_spawn_save_config_callback(tmpdir, logger, trainer_kwargs): - with mock.patch("sys.argv", ["any.py"]), pytest.raises(CustomException): + with patch("sys.argv", ["any.py"]), pytest.raises(CustomException): LightningCLI( EarlyExitTestModel, trainer_defaults={ @@ -615,11 +615,11 @@ def test_cli_ddp_spawn_save_config_callback(tmpdir, logger, trainer_kwargs): def test_cli_config_overwrite(tmpdir): trainer_defaults = {"default_root_dir": str(tmpdir), "logger": False, "max_steps": 1, "max_epochs": 1} - with mock.patch("sys.argv", ["any.py"]): + with patch("sys.argv", ["any.py"]): LightningCLI(BoringModel, trainer_defaults=trainer_defaults) - with mock.patch("sys.argv", ["any.py"]), pytest.raises(RuntimeError, match="Aborting to avoid overwriting"): + with patch("sys.argv", ["any.py"]), pytest.raises(RuntimeError, match="Aborting to avoid overwriting"): LightningCLI(BoringModel, trainer_defaults=trainer_defaults) - with mock.patch("sys.argv", ["any.py"]): + with patch("sys.argv", ["any.py"]): LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=trainer_defaults) @@ -638,7 +638,7 @@ def add_arguments_to_parser(self, parser): "BoringModel.configure_optimizers` will be overridden by " "`MyLightningCLI.add_configure_optimizers_method_to_model`" ) - with mock.patch("sys.argv", ["any.py"] + cli_args), pytest.warns(UserWarning, match=match): + with patch("sys.argv", ["any.py"] + cli_args), pytest.warns(UserWarning, match=match): cli = MyLightningCLI(BoringModel) assert cli.model.configure_optimizers is not BoringModel.configure_optimizers @@ -660,7 +660,7 @@ def add_arguments_to_parser(self, parser): "--lr_scheduler.gamma=0.8", ] - with mock.patch("sys.argv", ["any.py"] + cli_args): + with patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI(BoringModel) assert cli.model.configure_optimizers is not BoringModel.configure_optimizers @@ -693,7 +693,7 @@ def add_arguments_to_parser(self, parser): f"--lr_scheduler={json.dumps(lr_scheduler_arg)}", ] - with mock.patch("sys.argv", ["any.py"] + cli_args): + with patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI(BoringModel) assert len(cli.trainer.optimizers) == 1 @@ -732,7 +732,7 @@ def add_arguments_to_parser(self, parser): "--lr_scheduler.gamma=0.2", ] - with mock.patch("sys.argv", ["any.py"] + cli_args): + with patch("sys.argv", ["any.py"] + cli_args): cli = MyLightningCLI(TestModelOptLR) assert isinstance(cli.model.optim1, torch.optim.Adam) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index e37d1a039b..80fea16b43 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -14,7 +14,7 @@ import os import sys from pathlib import Path -from unittest import mock +from unittest.mock import patch import pytest import torch @@ -43,7 +43,7 @@ root = Path(__file__).parent.parent.parent -@mock.patch.dict(os.environ, {"FLASH_TESTING": "1"}) +@patch.dict(os.environ, {"FLASH_TESTING": "1"}) @pytest.mark.parametrize( "folder,fname", [ diff --git a/tests/helpers/task_tester.py b/tests/helpers/task_tester.py index 3f34b9ecc3..497426af3f 100644 --- a/tests/helpers/task_tester.py +++ b/tests/helpers/task_tester.py @@ -17,7 +17,7 @@ import types from abc import ABCMeta from typing import Any, Dict, List, Optional, Tuple -from unittest import mock +from unittest.mock import patch import pytest import torch @@ -130,7 +130,7 @@ def _test_jit_script(self, tmpdir): def _test_cli(self, extra_args: List): """Tests that the default Flash zero configuration runs for the task.""" cli_args = ["flash", self.cli_command, "--trainer.fast_dev_run", "True"] + extra_args - with mock.patch("sys.argv", cli_args): + with patch("sys.argv", cli_args): try: main() except SystemExit: diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py index c72c3ae109..8de7a109ac 100644 --- a/tests/image/classification/test_model.py +++ b/tests/image/classification/test_model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any -from unittest import mock +from unittest.mock import patch import pytest import torch @@ -150,7 +150,7 @@ def test_multilabel(tmpdir): @pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") -@mock.patch("flash._IS_TESTING", True) +@patch("flash._IS_TESTING", True) def test_serve(): model = ImageClassifier(2) model.eval() diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index bc2fde9a12..9465b69df0 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -13,7 +13,7 @@ # limitations under the License. import random from typing import Any -from unittest import mock +from unittest.mock import patch import numpy as np import pytest @@ -141,7 +141,7 @@ def test_predict(tmpdir, head): @pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") -@mock.patch("flash._IS_TESTING", True) +@patch("flash._IS_TESTING", True) def test_serve(): model = ObjectDetector(2) model.eval() diff --git a/tests/image/face_detection/test_model.py b/tests/image/face_detection/test_model.py index d57b8b8590..d4f49a6777 100644 --- a/tests/image/face_detection/test_model.py +++ b/tests/image/face_detection/test_model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest import mock +from unittest.mock import patch import pytest @@ -57,7 +57,7 @@ def test_fastface_backbones_registry(): @pytest.mark.skipif(not _FASTFACE_AVAILABLE, reason="fastface not installed.") def test_cli(): cli_args = ["flash", "face_detection", "--trainer.fast_dev_run", "True"] - with mock.patch("sys.argv", cli_args): + with patch("sys.argv", cli_args): try: main() except SystemExit: diff --git a/tests/image/semantic_segm/test_model.py b/tests/image/semantic_segm/test_model.py index 142e64c614..c58b0a1632 100644 --- a/tests/image/semantic_segm/test_model.py +++ b/tests/image/semantic_segm/test_model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any -from unittest import mock +from unittest.mock import patch import numpy as np import pytest @@ -108,7 +108,7 @@ def test_predict_numpy(): @pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP") @pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="some serving") -@mock.patch("flash._IS_TESTING", True) +@patch("flash._IS_TESTING", True) def test_serve(): model = SemanticSegmentation(2) model.eval() diff --git a/tests/serve/.gitkeep b/tests/serve/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/serve/__init__.py b/tests/serve/__init__.py deleted file mode 100644 index 81a3bbadf5..0000000000 --- a/tests/serve/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""This is just placeholder to have parity with domains/topics.""" diff --git a/tests/tabular/classification/test_data_model_integration.py b/tests/tabular/classification/test_data_model_integration.py index d9c86aa081..010d58b275 100644 --- a/tests/tabular/classification/test_data_model_integration.py +++ b/tests/tabular/classification/test_data_model_integration.py @@ -39,7 +39,8 @@ ("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), - ("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), + # ("category_embedding", # todo: seems to be bug in tabular + # {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), # No categorical / numerical fields ("tabnet", {"categorical_fields": ["category"]}), ("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}), diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index 5b6347dc70..008a797a99 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any -from unittest import mock +from unittest.mock import patch import pandas as pd import pytest @@ -55,7 +55,7 @@ class TestTabularClassifier(TaskTester): {"backbone": "fttransformer"}, {"backbone": "autoint"}, {"backbone": "node"}, - {"backbone": "category_embedding"}, + # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular ], ) ], @@ -68,7 +68,7 @@ class TestTabularClassifier(TaskTester): {"backbone": "fttransformer"}, {"backbone": "autoint"}, {"backbone": "node"}, - {"backbone": "category_embedding"}, + # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular ], ) ], @@ -81,7 +81,7 @@ class TestTabularClassifier(TaskTester): {"backbone": "fttransformer"}, {"backbone": "autoint"}, {"backbone": "node"}, - {"backbone": "category_embedding"}, + # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular ], ) ], @@ -144,7 +144,7 @@ def test_init_train_no_cat(self, backbone, tmpdir): @pytest.mark.parametrize( "backbone", ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"] ) -@mock.patch("flash._IS_TESTING", True) +@patch("flash._IS_TESTING", True) def test_serve(backbone): train_data = {"num_col": [1.4, 2.5], "cat_col": ["positive", "negative"], "target": [1, 2]} datamodule = TabularClassificationData.from_data_frame( diff --git a/tests/tabular/regression/test_data_model_integration.py b/tests/tabular/regression/test_data_model_integration.py index 3cfeb6fdb2..a2144c8ee9 100644 --- a/tests/tabular/regression/test_data_model_integration.py +++ b/tests/tabular/regression/test_data_model_integration.py @@ -48,7 +48,8 @@ ("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), - ("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), + # ("category_embedding", # todo: seems to be bug in tabular + # {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), # No categorical / numerical fields ("tabnet", {"categorical_fields": ["category"]}), ("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}), @@ -81,7 +82,8 @@ def test_regression_data_frame(backbone, fields, tmpdir): ("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), - ("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), + # ("category_embedding", # todo: seems to be bug in tabular + # {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), # No categorical / numerical fields ("tabnet", {"categorical_fields": ["category"]}), ("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}), @@ -111,7 +113,8 @@ def test_regression_dicts(backbone, fields, tmpdir): ("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), - ("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), + # ("category_embedding", # todo: seems to be bug in tabular + # {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), # No categorical / numerical fields ("tabnet", {"categorical_fields": ["category"]}), ("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}), diff --git a/tests/tabular/regression/test_model.py b/tests/tabular/regression/test_model.py index 39789d7681..52dd36ed9b 100644 --- a/tests/tabular/regression/test_model.py +++ b/tests/tabular/regression/test_model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any -from unittest import mock +from unittest.mock import patch import pandas as pd import pytest @@ -53,7 +53,7 @@ class TestTabularRegressor(TaskTester): {"backbone": "fttransformer"}, {"backbone": "autoint"}, {"backbone": "node"}, - {"backbone": "category_embedding"}, + # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular ], ) ], @@ -66,7 +66,7 @@ class TestTabularRegressor(TaskTester): {"backbone": "fttransformer"}, {"backbone": "autoint"}, {"backbone": "node"}, - {"backbone": "category_embedding"}, + # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular ], ) ], @@ -79,7 +79,7 @@ class TestTabularRegressor(TaskTester): {"backbone": "fttransformer"}, {"backbone": "autoint"}, {"backbone": "node"}, - {"backbone": "category_embedding"}, + # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular ], ) ], @@ -142,7 +142,7 @@ def test_init_train_no_cat(self, backbone, tmpdir): @pytest.mark.parametrize( "backbone", ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"] ) -@mock.patch("flash._IS_TESTING", True) +@patch("flash._IS_TESTING", True) def test_serve(backbone): train_data = {"num_col": [1.4, 2.5], "cat_col": ["positive", "negative"], "target": [1, 2]} datamodule = TabularRegressionData.from_data_frame( diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py index 26c71b7778..63830a97df 100644 --- a/tests/text/classification/test_model.py +++ b/tests/text/classification/test_model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any -from unittest import mock +from unittest.mock import patch import pytest import torch @@ -107,7 +107,7 @@ def test_ort_callback_fails_no_model(self, tmpdir): @pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") -@mock.patch("flash._IS_TESTING", True) +@patch("flash._IS_TESTING", True) def test_serve(): model = TextClassifier(2, backbone=TEST_BACKBONE) model.eval() diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py index bf9c6f4ede..1513cbe5b6 100644 --- a/tests/text/seq2seq/summarization/test_model.py +++ b/tests/text/seq2seq/summarization/test_model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any -from unittest import mock +from unittest.mock import patch import pytest import torch @@ -62,7 +62,7 @@ def example_test_sample(self): @pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") -@mock.patch("flash._IS_TESTING", True) +@patch("flash._IS_TESTING", True) def test_serve(): model = SummarizationTask(TEST_BACKBONE) model.eval() diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py index b26e767d9e..e6608de7ec 100644 --- a/tests/text/seq2seq/translation/test_model.py +++ b/tests/text/seq2seq/translation/test_model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any -from unittest import mock +from unittest.mock import patch import pytest import torch @@ -62,7 +62,7 @@ def example_test_sample(self): @pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.") -@mock.patch("flash._IS_TESTING", True) +@patch("flash._IS_TESTING", True) def test_serve(): model = TranslationTask(TEST_BACKBONE) model.eval() From e10fca584567315f7d50005f8fe90869a1f0b558 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 11 May 2023 14:51:49 +0200 Subject: [PATCH 29/39] update dependabot --- .github/dependabot.yml | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 0cef51216f..2b16b225d3 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,14 +1,38 @@ -# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file +# Basic dependabot.yml file with minimum configuration for two package managers + version: 2 updates: + # Enable version updates for python + - package-ecosystem: "pip" + # Look for a `requirements` in the `root` directory + directory: "/" + # Check for updates once a week + schedule: + interval: "monthly" + # Labels on pull requests for version updates only + labels: ["enhancement"] + pull-request-branch-name: + # Separate sections of the branch name with a hyphen + # for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1` + separator: "-" + # Allow up to 5 open pull requests for pip dependencies + open-pull-requests-limit: 10 + reviewers: + - "Lightning-Universe/engs" + + # Enable version updates for GitHub Actions - package-ecosystem: "github-actions" directory: "/" + # Check for updates once a week schedule: - interval: "weekly" - labels: - - "tests / CI" + interval: "monthly" + # Labels on pull requests for version updates only + labels: ["tests / CI"] pull-request-branch-name: + # Separate sections of the branch name with a hyphen + # for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1` separator: "-" + # Allow up to 5 open pull requests for GitHub Actions open-pull-requests-limit: 5 reviewers: - - "Lightning-AI/core-flash" + - "Lightning-Universe/engs" From 8a0a962f42a1cf3c8b20253dc318e63abf551558 Mon Sep 17 00:00:00 2001 From: Izik Golan <47969623+izikgo@users.noreply.github.com> Date: Thu, 11 May 2023 17:31:49 +0300 Subject: [PATCH 30/39] fix channel dim selection on segmentation target (#1509) Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka --- src/flash/image/segmentation/input.py | 2 +- tests/image/semantic_segm/test_data.py | 70 ++++++++++++++++++++++++-- 2 files changed, 67 insertions(+), 5 deletions(-) diff --git a/src/flash/image/segmentation/input.py b/src/flash/image/segmentation/input.py index 4662e1c986..02068c96be 100644 --- a/src/flash/image/segmentation/input.py +++ b/src/flash/image/segmentation/input.py @@ -94,7 +94,7 @@ def load_data( def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: if DataKeys.TARGET in sample: - sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET])).transpose((2, 0, 1))[:, :, 0] + sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET])).transpose((2, 0, 1))[0, :, :] return super().load_sample(sample) diff --git a/tests/image/semantic_segm/test_data.py b/tests/image/semantic_segm/test_data.py index 0c07e99fb0..aae5129c90 100644 --- a/tests/image/semantic_segm/test_data.py +++ b/tests/image/semantic_segm/test_data.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Dict, List, Tuple +from typing import Callable, Dict, List, Tuple import numpy as np import pytest @@ -8,6 +8,8 @@ from flash import Trainer from flash.core.data.io.input import DataKeys +from flash.core.data.io.input_transform import InputTransform +from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, _MATPLOTLIB_AVAILABLE, @@ -43,12 +45,22 @@ def _rand_labels(size: Tuple[int, int], num_classes: int): return Image.fromarray(data.astype(np.uint8)) -def create_random_data(image_files: List[str], label_files: List[str], size: Tuple[int, int], num_classes: int): +def create_random_data( + image_files: List[str], label_files: List[str], size: Tuple[int, int], num_classes: int +) -> Tuple[List[Image.Image], List[Image.Image]]: + imgs = [] for img_file in image_files: - _rand_image(size).save(img_file) + img = _rand_image(size) + img.save(img_file) + imgs.append(img) + labels = [] for label_file in label_files: - _rand_labels(size, num_classes).save(label_file) + label = _rand_labels(size, num_classes) + label.save(label_file) + labels.append(label) + + return imgs, labels @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @@ -58,6 +70,56 @@ def test_smoke(): dm = SemanticSegmentationData(batch_size=1) assert dm is not None + @staticmethod + @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") + def test_identity(tmpdir): + class IdentityTransform(InputTransform): + def per_sample_transform(self) -> Callable: + return ApplyToKeys( + DataKeys.INPUT, + np.array, + ) + + def per_batch_transform(self) -> Callable: + return lambda x: x + + tmp_dir = Path(tmpdir) + + # create random dummy data + + os.makedirs(str(tmp_dir / "images")) + os.makedirs(str(tmp_dir / "targets")) + + images = [str(tmp_dir / "images" / "img1.png")] + + targets = [str(tmp_dir / "targets" / "img1.png")] + + num_classes: int = 2 + img_size: Tuple[int, int] = (128, 128) + images_data, targets_data = create_random_data(images, targets, img_size, num_classes) + + # instantiate the data module + + dm = SemanticSegmentationData.from_files( + test_files=images, + test_targets=targets, + batch_size=1, + num_workers=0, + num_classes=num_classes, + transform=IdentityTransform(), + ) + + assert dm is not None + assert dm.test_dataloader() is not None + + # check test data + data = next(iter(dm.test_dataloader())) + imgs, labels = data[DataKeys.INPUT], data[DataKeys.TARGET] + assert imgs.shape == (1, 128, 128, 3) + assert labels.shape == (1, 128, 128) + assert torch.allclose(imgs, torch.from_numpy(np.array(images_data[0]))) + assert torch.allclose(labels, torch.from_numpy(np.array(targets_data[0]))[:, :, 0]) + @staticmethod def test_from_folders(tmpdir): tmp_dir = Path(tmpdir) From 2780f4f3e699bbb3a20761110a09f7335916a410 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 11 May 2023 17:09:00 +0200 Subject: [PATCH 31/39] bump min py3.8 (#1552) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/ci-testing.yml | 2 +- .github/workflows/docs-check.yml | 4 +- .github/workflows/pypi-release.yml | 2 +- .pre-commit-config.yaml | 2 +- setup.py | 4 +- .../core/serve/_compat/cached_property.py | 70 +------------------ src/flash/core/utilities/stability.py | 2 +- tests/core/utilities/test_lightning_cli.py | 5 -- 8 files changed, 10 insertions(+), 81 deletions(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 80e6b2dee4..ed39516f73 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -41,7 +41,7 @@ jobs: - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'serve', extra: [] } - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'graph', extra: [] } - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'audio', extra: [] } - - { os: 'ubuntu-20.04', python-version: 3.7, topic: 'core', extra: [], requires: 'oldest' } + - { os: 'ubuntu-20.04', python-version: 3.8, topic: 'core', extra: [], requires: 'oldest' } # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 50 diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml index 46a7c910b5..c4c4e26d07 100644 --- a/.github/workflows/docs-check.yml +++ b/.github/workflows/docs-check.yml @@ -19,7 +19,7 @@ jobs: submodules: true - uses: actions/setup-python@v4 with: - python-version: 3.7 + python-version: 3.8 # Note: This uses an internal pip API and may not always work # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow @@ -60,7 +60,7 @@ jobs: submodules: true - uses: actions/setup-python@v4 with: - python-version: 3.7 + python-version: 3.8 # Note: This uses an internal pip API and may not always work # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml index 3e71d514e1..5736e12f29 100644 --- a/.github/workflows/pypi-release.yml +++ b/.github/workflows/pypi-release.yml @@ -17,7 +17,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: 3.7 + python-version: 3.8 - name: Install dependencies run: pip install --user --upgrade setuptools wheel build diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e63f80f90..549f5e8e42 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,7 +39,7 @@ repos: rev: v3.3.2 hooks: - id: pyupgrade - args: [--py37-plus] + args: [--py38-plus] name: Upgrade code - repo: https://github.com/kynan/nbstripout diff --git a/setup.py b/setup.py index 88267df0c9..56074f92e5 100644 --- a/setup.py +++ b/setup.py @@ -170,7 +170,7 @@ def _get_extras(path_dir: str = _PATH_REQUIRE) -> dict: }, zip_safe=False, keywords=["deep learning", "pytorch", "AI"], - python_requires=">=3.7", + python_requires=">=3.8", install_requires=_load_requirements(path_dir=_PATH_ROOT, file_name="requirements.txt"), extras_require=_get_extras(), project_urls={ @@ -195,8 +195,8 @@ def _get_extras(path_dir: str = _PATH_REQUIRE) -> dict: # Specify the Python versions you support here. In particular, ensure # that you indicate whether you support Python 2, Python 3 or both. "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", ], ) diff --git a/src/flash/core/serve/_compat/cached_property.py b/src/flash/core/serve/_compat/cached_property.py index 2adde68103..50327f8d3f 100644 --- a/src/flash/core/serve/_compat/cached_property.py +++ b/src/flash/core/serve/_compat/cached_property.py @@ -8,72 +8,6 @@ __all__ = ("cached_property",) # Standard Library -from sys import version_info +from functools import cached_property # pylint: disable=no-name-in-module -if version_info >= (3, 8): - # Standard Library - from functools import cached_property # pylint: disable=no-name-in-module -else: - # Standard Library - from threading import RLock - from typing import Any, Callable, Optional, Type, TypeVar - - _NOT_FOUND = object() - _T = TypeVar("_T") - _S = TypeVar("_S") - - # noinspection PyPep8Naming - class cached_property: # NOSONAR # pylint: disable=invalid-name # noqa: N801 - """Cached property implementation. - - Transform a method of a class into a property whose value is computed once and then cached as a normal attribute - for the life of the instance. Similar to property(), with the addition of caching. Useful for expensive computed - properties of instances that are otherwise effectively immutable. - """ - - def __init__(self, func: Callable[[Any], _T]) -> None: - """Cached property implementation.""" - self.func = func - self.attrname: Optional[str] = None - self.__doc__ = func.__doc__ - self.lock = RLock() - - def __set_name__(self, owner: Type[Any], name: str) -> None: - """Assign attribute name and owner.""" - if self.attrname is None: - self.attrname = name - elif name != self.attrname: - raise TypeError( - "Cannot assign the same cached_property to two different names " - f"({self.attrname!r} and {name!r})." - ) - - def __get__(self, instance, owner=None) -> Any: - if instance is None: - return self - if self.attrname is None: - raise TypeError("Cannot use cached_property instance without calling __set_name__ on it.") - try: - cache = instance.__dict__ - except AttributeError: # not all objects have __dict__ (e.g. class defines slots) - msg = ( - f"No '__dict__' attribute on {type(instance).__name__!r} " - f"instance to cache {self.attrname!r} property." - ) - raise TypeError(msg) from None - val = cache.get(self.attrname, _NOT_FOUND) - if val is _NOT_FOUND: - with self.lock: - # check if another thread filled cache while we awaited lock - val = cache.get(self.attrname, _NOT_FOUND) - if val is _NOT_FOUND: - val = self.func(instance) - try: - cache[self.attrname] = val - except TypeError: - msg = ( - f"The '__dict__' attribute on {type(instance).__name__!r} instance " - f"does not support item assignment for caching {self.attrname!r} property." - ) - raise TypeError(msg) from None - return val +# Standard Library diff --git a/src/flash/core/utilities/stability.py b/src/flash/core/utilities/stability.py index 16e29a0d7d..8dd553a8a8 100644 --- a/src/flash/core/utilities/stability.py +++ b/src/flash/core/utilities/stability.py @@ -24,7 +24,7 @@ __doctest_skip__ = ["beta"] -@functools.lru_cache() # Trick to only warn once for each message +@functools.lru_cache # Trick to only warn once for each message def _raise_beta_warning(message: str, stacklevel: int = 6): rank_zero_warn( f"{message} The API and functionality may change without warning in future releases. " diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py index 5602c32621..bc7e6e9b3b 100644 --- a/tests/core/utilities/test_lightning_cli.py +++ b/tests/core/utilities/test_lightning_cli.py @@ -4,7 +4,6 @@ import json import os import pickle -import sys from argparse import Namespace from contextlib import redirect_stdout from io import StringIO @@ -148,10 +147,6 @@ def test_parse_args_parsing_gpus(mocker, cli_args, expected_gpu): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") -@pytest.mark.skipif( - sys.version_info < (3, 7), - reason="signature inspection while mocking is not working in Python < 3.7 despite autospec", -) @pytest.mark.parametrize( ["cli_args", "extra_args"], [ From e55bfa1062de429fbf07d77ca0dc98ff407a102a Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 11 May 2023 17:10:35 +0200 Subject: [PATCH 32/39] codecov: informational --- .codecov.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.codecov.yml b/.codecov.yml index bdb5074d4f..646217fb2d 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -20,6 +20,7 @@ coverage: # https://codecov.readme.io/v1.0/docs/commit-status project: default: + informational: true against: auto target: 99% # specify the target coverage for each commit status threshold: 30% # allow this little decrease on project @@ -29,6 +30,7 @@ coverage: # https://github.com/codecov/support/wiki/Patch-Status patch: default: + informational: true against: auto target: 50% # specify the target "X%" coverage to hit # threshold: 50% # allow this much decrease on patch From fd7987abfd0d2bb73fd99a434b7503cca883f7c8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 11 May 2023 17:14:33 +0200 Subject: [PATCH 33/39] Delete config.yaml --- config.yaml | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 config.yaml diff --git a/config.yaml b/config.yaml deleted file mode 100644 index e69de29bb2..0000000000 From c3dceb36b01b63cce070482177f552f91a74057e Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 11 May 2023 18:15:20 +0200 Subject: [PATCH 34/39] ruff: enable C4 (#1553) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- examples/audio/audio_classification.py | 2 +- examples/image/semantic_segmentation.py | 2 +- pyproject.toml | 3 + src/flash/__main__.py | 8 +- src/flash/audio/classification/data.py | 52 +++++------ src/flash/audio/speech_recognition/data.py | 30 +++---- src/flash/core/data/io/input.py | 2 +- src/flash/core/data/splits.py | 4 +- src/flash/core/data/utils.py | 4 +- .../core/integrations/labelstudio/input.py | 2 +- src/flash/core/optimizers/lamb.py | 16 ++-- src/flash/core/optimizers/lars.py | 8 +- src/flash/core/serve/dag/optimization.py | 4 +- src/flash/graph/classification/data.py | 6 +- src/flash/image/classification/data.py | 52 +++++------ src/flash/image/detection/cli.py | 2 +- src/flash/image/detection/data.py | 28 +++--- src/flash/image/face_detection/data.py | 2 +- src/flash/image/instance_segmentation/cli.py | 2 +- src/flash/image/instance_segmentation/data.py | 2 +- src/flash/image/keypoint_detection/cli.py | 2 +- src/flash/image/keypoint_detection/data.py | 2 +- src/flash/image/segmentation/data.py | 32 +++---- src/flash/pointcloud/detection/data.py | 26 +++--- src/flash/pointcloud/segmentation/data.py | 6 +- src/flash/tabular/classification/data.py | 56 ++++++------ src/flash/tabular/input.py | 14 +-- src/flash/tabular/regression/data.py | 48 +++++------ src/flash/template/classification/data.py | 4 +- src/flash/text/classification/data.py | 60 ++++++------- src/flash/text/question_answering/data.py | 42 ++++----- src/flash/text/question_answering/input.py | 4 +- src/flash/text/seq2seq/core/input.py | 2 +- src/flash/text/seq2seq/summarization/data.py | 28 +++--- src/flash/text/seq2seq/translation/data.py | 28 +++--- src/flash/text/seq2seq/translation/model.py | 2 +- src/flash/video/classification/data.py | 86 +++++++++---------- tests/audio/classification/test_data.py | 6 +- tests/audio/speech_recognition/test_model.py | 2 +- tests/core/serve/test_components.py | 4 +- .../core/serve/test_dag/test_optimization.py | 2 +- tests/core/utilities/test_lightning_cli.py | 80 ++++++++--------- .../classification/test_active_learning.py | 2 +- tests/image/classification/test_data.py | 14 +-- .../test_training_strategies.py | 2 +- tests/image/detection/test_data.py | 12 +-- tests/image/embedding/test_model.py | 8 +- tests/image/instance_segm/test_data.py | 4 +- tests/image/instance_segm/test_model.py | 2 +- tests/image/keypoint_detection/test_data.py | 4 +- tests/image/keypoint_detection/test_model.py | 2 +- tests/text/classification/test_data.py | 42 ++++----- 52 files changed, 435 insertions(+), 424 deletions(-) diff --git a/examples/audio/audio_classification.py b/examples/audio/audio_classification.py index 51415b9e94..2fc70a21aa 100644 --- a/examples/audio/audio_classification.py +++ b/examples/audio/audio_classification.py @@ -24,7 +24,7 @@ datamodule = AudioClassificationData.from_folders( train_folder="data/urban8k_images/train", val_folder="data/urban8k_images/val", - transform_kwargs=dict(spectrogram_size=(64, 64)), + transform_kwargs={"spectrogram_size": (64, 64)}, batch_size=4, ) diff --git a/examples/image/semantic_segmentation.py b/examples/image/semantic_segmentation.py index 39e5d26d2d..1b1a93cf93 100644 --- a/examples/image/semantic_segmentation.py +++ b/examples/image/semantic_segmentation.py @@ -29,7 +29,7 @@ train_folder="data/CameraRGB", train_target_folder="data/CameraSeg", val_split=0.1, - transform_kwargs=dict(image_size=(256, 256)), + transform_kwargs={"image_size": (256, 256)}, num_classes=21, batch_size=4, ) diff --git a/pyproject.toml b/pyproject.toml index 38ebddb0b7..79a4f52855 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,9 @@ select = [ # "D", # see: https://pypi.org/project/pydocstyle # "N", # see: https://pypi.org/project/pep8-naming ] +extend-select = [ + "C4", # see: https://pypi.org/project/flake8-comprehensions +] ignore = [ "E731", # Do not assign a lambda expression, use a def ] diff --git a/src/flash/__main__.py b/src/flash/__main__.py index 1f521bb2a8..b7f0615cfe 100644 --- a/src/flash/__main__.py +++ b/src/flash/__main__.py @@ -26,10 +26,10 @@ def main(): def register_command(command): @main.command( command.__name__, - context_settings=dict( - help_option_names=[], - ignore_unknown_options=True, - ), + context_settings={ + "help_option_names": [], + "ignore_unknown_options": True, + }, ) @click.argument("cli_args", nargs=-1, type=click.UNPROCESSED) @functools.wraps(command) diff --git a/src/flash/audio/classification/data.py b/src/flash/audio/classification/data.py index 721c0fa36b..951222b033 100644 --- a/src/flash/audio/classification/data.py +++ b/src/flash/audio/classification/data.py @@ -143,11 +143,11 @@ def from_files( >>> _ = [os.remove(f"predict_spectrogram_{i}.png") for i in range(1, 4)] """ - ds_kw = dict( - sampling_rate=sampling_rate, - n_fft=n_fft, - target_formatter=target_formatter, - ) + ds_kw = { + "sampling_rate": sampling_rate, + "n_fft": n_fft, + "target_formatter": target_formatter, + } train_input = input_cls(RunningStage.TRAINING, train_files, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -277,11 +277,11 @@ def from_folders( >>> shutil.rmtree("predict_folder") """ - ds_kw = dict( - sampling_rate=sampling_rate, - n_fft=n_fft, - target_formatter=target_formatter, - ) + ds_kw = { + "sampling_rate": sampling_rate, + "n_fft": n_fft, + "target_formatter": target_formatter, + } train_input = input_cls(RunningStage.TRAINING, train_folder, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -367,9 +367,9 @@ def from_numpy( Predicting... """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -455,9 +455,9 @@ def from_tensors( Predicting... """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -609,11 +609,11 @@ def from_data_frame( >>> del predict_data_frame """ - ds_kw = dict( - sampling_rate=sampling_rate, - n_fft=n_fft, - target_formatter=target_formatter, - ) + ds_kw = { + "sampling_rate": sampling_rate, + "n_fft": n_fft, + "target_formatter": target_formatter, + } train_data = (train_data_frame, input_field, target_fields, train_images_root, train_resolver) val_data = (val_data_frame, input_field, target_fields, val_images_root, val_resolver) @@ -856,11 +856,11 @@ def from_csv( >>> os.remove("predict_data.tsv") """ - ds_kw = dict( - sampling_rate=sampling_rate, - n_fft=n_fft, - target_formatter=target_formatter, - ) + ds_kw = { + "sampling_rate": sampling_rate, + "n_fft": n_fft, + "target_formatter": target_formatter, + } train_data = (train_file, input_field, target_fields, train_images_root, train_resolver) val_data = (val_file, input_field, target_fields, val_images_root, val_resolver) diff --git a/src/flash/audio/speech_recognition/data.py b/src/flash/audio/speech_recognition/data.py index 86127c41e6..b205afcd09 100644 --- a/src/flash/audio/speech_recognition/data.py +++ b/src/flash/audio/speech_recognition/data.py @@ -119,9 +119,9 @@ def from_files( >>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)] """ - ds_kw = dict( - sampling_rate=sampling_rate, - ) + ds_kw = { + "sampling_rate": sampling_rate, + } return cls( input_cls(RunningStage.TRAINING, train_files, train_targets, **ds_kw), @@ -306,10 +306,10 @@ def from_csv( >>> os.remove("predict_data.tsv") """ - ds_kw = dict( - input_key=input_field, - sampling_rate=sampling_rate, - ) + ds_kw = { + "input_key": input_field, + "sampling_rate": sampling_rate, + } return cls( input_cls(RunningStage.TRAINING, train_file, target_key=target_field, **ds_kw), @@ -430,11 +430,11 @@ def from_json( >>> os.remove("predict_data.json") """ - ds_kw = dict( - input_key=input_field, - sampling_rate=sampling_rate, - field=field, - ) + ds_kw = { + "input_key": input_field, + "sampling_rate": sampling_rate, + "field": field, + } return cls( input_cls(RunningStage.TRAINING, train_file, target_key=target_field, **ds_kw), @@ -580,9 +580,9 @@ def from_datasets( >>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)] """ - ds_kw = dict( - sampling_rate=sampling_rate, - ) + ds_kw = { + "sampling_rate": sampling_rate, + } return cls( input_cls(RunningStage.TRAINING, train_dataset, **ds_kw), diff --git a/src/flash/core/data/io/input.py b/src/flash/core/data/io/input.py index f8e3bdc96e..6e1491d44e 100644 --- a/src/flash/core/data/io/input.py +++ b/src/flash/core/data/io/input.py @@ -38,7 +38,7 @@ def _deepcopy_dict(nested_dict: Any) -> Any: """Utility to deepcopy a nested dict.""" if not isinstance(nested_dict, Dict): return nested_dict - return {key: value for key, value in nested_dict.items()} + return dict(nested_dict.items()) class InputFormat(LightningEnum): diff --git a/src/flash/core/data/splits.py b/src/flash/core/data/splits.py index f15ce6ca64..8862b70152 100644 --- a/src/flash/core/data/splits.py +++ b/src/flash/core/data/splits.py @@ -32,9 +32,9 @@ def __init__( ) -> None: kwargs = {} if running_stage is not None: - kwargs = dict(running_stage=running_stage) + kwargs = {"running_stage": running_stage} elif isinstance(dataset, Properties): - kwargs = dict(running_stage=dataset._running_stage) + kwargs = {"running_stage": dataset._running_stage} super().__init__(**kwargs) if not isinstance(indices, list): diff --git a/src/flash/core/data/utils.py b/src/flash/core/data/utils.py index 3c420f682a..fb85435f45 100644 --- a/src/flash/core/data/utils.py +++ b/src/flash/core/data/utils.py @@ -88,8 +88,8 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: chunk_size = 1024 num_bars = int(file_size / chunk_size) if verbose: - print(dict(file_size=file_size)) - print(dict(num_bars=num_bars)) + print({"file_size": file_size}) + print({"num_bars": num_bars}) if not os.path.exists(local_filename): with open(local_filename, "wb") as fp: diff --git a/src/flash/core/integrations/labelstudio/input.py b/src/flash/core/integrations/labelstudio/input.py index 6d2ee49374..723bf857e8 100644 --- a/src/flash/core/integrations/labelstudio/input.py +++ b/src/flash/core/integrations/labelstudio/input.py @@ -33,7 +33,7 @@ class LabelStudioParameters: def _get_labels_from_sample(labels, classes): """Translate string labels to int.""" - sorted_labels = sorted(list(classes)) + sorted_labels = sorted(classes) return [sorted_labels.index(item) for item in labels] if isinstance(labels, list) else sorted_labels.index(labels) diff --git a/src/flash/core/optimizers/lamb.py b/src/flash/core/optimizers/lamb.py index 53c1fb7038..853d3f8668 100644 --- a/src/flash/core/optimizers/lamb.py +++ b/src/flash/core/optimizers/lamb.py @@ -82,14 +82,14 @@ def __init__( raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - exclude_from_layer_adaptation=exclude_from_layer_adaptation, - amsgrad=amsgrad, - ) + defaults = { + "lr": lr, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "exclude_from_layer_adaptation": exclude_from_layer_adaptation, + "amsgrad": amsgrad, + } super().__init__(params, defaults) def __setstate__(self, state): diff --git a/src/flash/core/optimizers/lars.py b/src/flash/core/optimizers/lars.py index fd3334e0e5..e1f0e43659 100644 --- a/src/flash/core/optimizers/lars.py +++ b/src/flash/core/optimizers/lars.py @@ -92,7 +92,13 @@ def __init__( if weight_decay < 0.0: raise ValueError(f"Invalid weight_decay value: {weight_decay}") - defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov) + defaults = { + "lr": lr, + "momentum": momentum, + "dampening": dampening, + "weight_decay": weight_decay, + "nesterov": nesterov, + } if nesterov and (momentum <= 0 or dampening != 0): raise ValueError("Nesterov momentum requires a momentum and zero dampening") diff --git a/src/flash/core/serve/dag/optimization.py b/src/flash/core/serve/dag/optimization.py index 2909d54e85..7b84d264d7 100644 --- a/src/flash/core/serve/dag/optimization.py +++ b/src/flash/core/serve/dag/optimization.py @@ -36,7 +36,7 @@ def cull(dsk, keys): keys = [keys] seen = set() - dependencies = dict() + dependencies = {} out = {} work = list(set(flatten(keys))) @@ -900,4 +900,4 @@ def __reduce__(self): return SubgraphCallable, (self.dsk, self.outkey, self.inkeys, self.name) def __hash__(self): - return hash(tuple((self.outkey, tuple(self.inkeys), self.name))) + return hash((self.outkey, tuple(self.inkeys), self.name)) diff --git a/src/flash/graph/classification/data.py b/src/flash/graph/classification/data.py index b0c909b53d..457d7b45c2 100644 --- a/src/flash/graph/classification/data.py +++ b/src/flash/graph/classification/data.py @@ -166,9 +166,9 @@ def from_datasets( Predicting... """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_input = input_cls(RunningStage.TRAINING, train_dataset, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) diff --git a/src/flash/image/classification/data.py b/src/flash/image/classification/data.py index 55356bfea6..e57065fa4a 100644 --- a/src/flash/image/classification/data.py +++ b/src/flash/image/classification/data.py @@ -161,9 +161,9 @@ def from_files( >>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)] >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_input = input_cls(RunningStage.TRAINING, train_files, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -281,9 +281,9 @@ def from_folders( >>> shutil.rmtree("train_folder") >>> shutil.rmtree("predict_folder") """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_input = input_cls(RunningStage.TRAINING, train_folder, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -367,9 +367,9 @@ def from_numpy( >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -458,9 +458,9 @@ def from_images( >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_input = input_cls(RunningStage.TRAINING, train_images, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -544,9 +544,9 @@ def from_tensors( >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -683,9 +683,9 @@ def from_data_frame( >>> del train_data_frame >>> del predict_data_frame """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_data = (train_data_frame, input_field, target_fields, train_images_root, train_resolver) val_data = (val_data_frame, input_field, target_fields, val_images_root, val_resolver) @@ -914,9 +914,9 @@ def from_csv( >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_data = (train_file, input_field, target_fields, train_images_root, train_resolver) val_data = (val_file, input_field, target_fields, val_images_root, val_resolver) @@ -1033,9 +1033,9 @@ def from_fiftyone( >>> del train_dataset >>> del predict_dataset """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_input = input_cls(RunningStage.TRAINING, train_dataset, label_field=label_field, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -1121,7 +1121,7 @@ def from_labelstudio( multi_label=multi_label, ) - ds_kw = dict() + ds_kw = {} train_input = input_cls(RunningStage.TRAINING, train_data, **ds_kw) ds_kw["parameters"] = getattr(train_input, "parameters", None) @@ -1172,7 +1172,7 @@ def from_datasets( train_dataset=train_dataset, ) """ - ds_kw = dict() + ds_kw = {} return cls( input_cls(RunningStage.TRAINING, train_dataset, **ds_kw), diff --git a/src/flash/image/detection/cli.py b/src/flash/image/detection/cli.py index 3fa6a8a05b..233ba979e5 100644 --- a/src/flash/image/detection/cli.py +++ b/src/flash/image/detection/cli.py @@ -32,7 +32,7 @@ def from_coco_128( train_folder="data/coco128/images/train2017/", train_ann_file="data/coco128/annotations/instances_train2017.json", val_split=val_split, - transform_kwargs=dict(image_size=(128, 128)) if transform_kwargs is None else transform_kwargs, + transform_kwargs={"image_size": (128, 128)} if transform_kwargs is None else transform_kwargs, batch_size=batch_size, **data_module_kwargs, ) diff --git a/src/flash/image/detection/data.py b/src/flash/image/detection/data.py index aafb8ce99d..2f5bd3f0cc 100644 --- a/src/flash/image/detection/data.py +++ b/src/flash/image/detection/data.py @@ -163,9 +163,9 @@ def from_files( >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_input = input_cls( RunningStage.TRAINING, @@ -281,9 +281,9 @@ def from_numpy( Predicting... """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_input = input_cls( RunningStage.TRAINING, @@ -404,9 +404,9 @@ def from_images( Predicting... """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_input = input_cls( RunningStage.TRAINING, @@ -522,9 +522,9 @@ def from_tensors( Predicting... """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_input = input_cls( RunningStage.TRAINING, @@ -576,7 +576,7 @@ def from_icedata( transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "ObjectDetectionData": - ds_kw = dict(parser=parser) + ds_kw = {"parser": parser} return cls( input_cls( @@ -1162,7 +1162,7 @@ def from_fiftyone( >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] """ - ds_kw = dict() + ds_kw = {} return cls( input_cls(RunningStage.TRAINING, train_dataset, label_field, iscrowd, **ds_kw), diff --git a/src/flash/image/face_detection/data.py b/src/flash/image/face_detection/data.py index 9bbc6e2aed..79321ff473 100644 --- a/src/flash/image/face_detection/data.py +++ b/src/flash/image/face_detection/data.py @@ -41,7 +41,7 @@ def from_datasets( transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "FaceDetectionData": - ds_kw = dict() + ds_kw = {} return cls( input_cls(RunningStage.TRAINING, train_dataset, **ds_kw), diff --git a/src/flash/image/instance_segmentation/cli.py b/src/flash/image/instance_segmentation/cli.py index 62cf9c838e..98fd285f4f 100644 --- a/src/flash/image/instance_segmentation/cli.py +++ b/src/flash/image/instance_segmentation/cli.py @@ -53,7 +53,7 @@ def from_pets( test_folder=test_folder, test_ann_file=test_ann_file, predict_folder=predict_folder, - transform_kwargs=dict(image_size=(128, 128)) if transform_kwargs is None else transform_kwargs, + transform_kwargs={"image_size": (128, 128)} if transform_kwargs is None else transform_kwargs, parser=parser, val_split=val_split, batch_size=batch_size, diff --git a/src/flash/image/instance_segmentation/data.py b/src/flash/image/instance_segmentation/data.py index 9676a629ff..0f9430e42a 100644 --- a/src/flash/image/instance_segmentation/data.py +++ b/src/flash/image/instance_segmentation/data.py @@ -85,7 +85,7 @@ def from_icedata( transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "InstanceSegmentationData": - ds_kw = dict(parser=parser) + ds_kw = {"parser": parser} return cls( input_cls( diff --git a/src/flash/image/keypoint_detection/cli.py b/src/flash/image/keypoint_detection/cli.py index 67b4154620..8cf4eaeade 100644 --- a/src/flash/image/keypoint_detection/cli.py +++ b/src/flash/image/keypoint_detection/cli.py @@ -53,7 +53,7 @@ def from_biwi( test_ann_file=test_ann_file, predict_folder=predict_folder, val_split=val_split, - transform_kwargs=dict(image_size=(128, 128)) if transform_kwargs is None else transform_kwargs, + transform_kwargs={"image_size": (128, 128)} if transform_kwargs is None else transform_kwargs, batch_size=batch_size, parser=parser, **data_module_kwargs, diff --git a/src/flash/image/keypoint_detection/data.py b/src/flash/image/keypoint_detection/data.py index f4324359fa..98fb0f2078 100644 --- a/src/flash/image/keypoint_detection/data.py +++ b/src/flash/image/keypoint_detection/data.py @@ -85,7 +85,7 @@ def from_icedata( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "KeypointDetectionData": - ds_kw = dict(parser=parser) + ds_kw = {"parser": parser} return cls( input_cls( diff --git a/src/flash/image/segmentation/data.py b/src/flash/image/segmentation/data.py index 815d041b19..363f5a6df4 100644 --- a/src/flash/image/segmentation/data.py +++ b/src/flash/image/segmentation/data.py @@ -148,10 +148,10 @@ def from_files( >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] """ - ds_kw = dict( - num_classes=num_classes, - labels_map=labels_map, - ) + ds_kw = { + "num_classes": num_classes, + "labels_map": labels_map, + } return cls( input_cls(RunningStage.TRAINING, train_files, train_targets, **ds_kw), @@ -289,10 +289,10 @@ def from_folders( >>> shutil.rmtree("predict_folder") """ - ds_kw = dict( - num_classes=num_classes, - labels_map=labels_map, - ) + ds_kw = { + "num_classes": num_classes, + "labels_map": labels_map, + } return cls( input_cls(RunningStage.TRAINING, train_folder, train_target_folder, **ds_kw), @@ -377,10 +377,10 @@ def from_numpy( Predicting... """ - ds_kw = dict( - num_classes=num_classes, - labels_map=labels_map, - ) + ds_kw = { + "num_classes": num_classes, + "labels_map": labels_map, + } return cls( input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw), @@ -465,10 +465,10 @@ def from_tensors( Predicting... """ - ds_kw = dict( - num_classes=num_classes, - labels_map=labels_map, - ) + ds_kw = { + "num_classes": num_classes, + "labels_map": labels_map, + } return cls( input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw), diff --git a/src/flash/pointcloud/detection/data.py b/src/flash/pointcloud/detection/data.py index a81049c33b..b69470fb08 100644 --- a/src/flash/pointcloud/detection/data.py +++ b/src/flash/pointcloud/detection/data.py @@ -48,12 +48,12 @@ def from_folders( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudObjectDetectorData": - ds_kw = dict( - scans_folder_name=scans_folder_name, - labels_folder_name=labels_folder_name, - calibrations_folder_name=calibrations_folder_name, - data_format=data_format, - ) + ds_kw = { + "scans_folder_name": scans_folder_name, + "labels_folder_name": labels_folder_name, + "calibrations_folder_name": calibrations_folder_name, + "data_format": data_format, + } return cls( input_cls(RunningStage.TRAINING, train_folder, **ds_kw), @@ -78,12 +78,12 @@ def from_files( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudObjectDetectorData": - ds_kw = dict( - scans_folder_name=scans_folder_name, - labels_folder_name=labels_folder_name, - calibrations_folder_name=calibrations_folder_name, - data_format=data_format, - ) + ds_kw = { + "scans_folder_name": scans_folder_name, + "labels_folder_name": labels_folder_name, + "calibrations_folder_name": calibrations_folder_name, + "data_format": data_format, + } return cls( predict_input=input_cls(RunningStage.PREDICTING, predict_files, **ds_kw), @@ -104,7 +104,7 @@ def from_datasets( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudObjectDetectorData": - ds_kw = dict() + ds_kw = {} return cls( input_cls(RunningStage.TRAINING, train_dataset, **ds_kw), diff --git a/src/flash/pointcloud/segmentation/data.py b/src/flash/pointcloud/segmentation/data.py index f07f1981e1..27522d6927 100644 --- a/src/flash/pointcloud/segmentation/data.py +++ b/src/flash/pointcloud/segmentation/data.py @@ -40,7 +40,7 @@ def from_folders( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudSegmentationData": - ds_kw = dict() + ds_kw = {} return cls( input_cls(RunningStage.TRAINING, train_folder, **ds_kw), @@ -61,7 +61,7 @@ def from_files( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudSegmentationData": - ds_kw = dict() + ds_kw = {} return cls( predict_input=input_cls(RunningStage.PREDICTING, predict_files, **ds_kw), @@ -82,7 +82,7 @@ def from_datasets( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "PointCloudSegmentationData": - ds_kw = dict() + ds_kw = {} return cls( input_cls(RunningStage.TRAINING, train_dataset, **ds_kw), diff --git a/src/flash/tabular/classification/data.py b/src/flash/tabular/classification/data.py index 464ffe59dc..9e5f89a6b8 100644 --- a/src/flash/tabular/classification/data.py +++ b/src/flash/tabular/classification/data.py @@ -157,13 +157,13 @@ def from_data_frame( >>> del train_data >>> del predict_data """ - ds_kw = dict( - target_formatter=target_formatter, - categorical_fields=categorical_fields, - numerical_fields=numerical_fields, - target_fields=target_fields, - parameters=parameters, - ) + ds_kw = { + "target_formatter": target_formatter, + "categorical_fields": categorical_fields, + "numerical_fields": numerical_fields, + "target_fields": target_fields, + "parameters": parameters, + } train_input = input_cls(RunningStage.TRAINING, train_data_frame, **ds_kw) ds_kw["parameters"] = train_input.parameters if train_input else parameters @@ -359,13 +359,13 @@ def from_csv( >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") """ - ds_kw = dict( - target_formatter=target_formatter, - categorical_fields=categorical_fields, - numerical_fields=numerical_fields, - target_fields=target_fields, - parameters=parameters, - ) + ds_kw = { + "target_formatter": target_formatter, + "categorical_fields": categorical_fields, + "numerical_fields": numerical_fields, + "target_fields": target_fields, + "parameters": parameters, + } train_input = input_cls(RunningStage.TRAINING, train_file, **ds_kw) ds_kw["parameters"] = train_input.parameters if train_input else parameters @@ -495,13 +495,13 @@ def from_dicts( >>> del train_data >>> del predict_data """ - ds_kw = dict( - target_formatter=target_formatter, - categorical_fields=categorical_fields, - numerical_fields=numerical_fields, - target_fields=target_fields, - parameters=parameters, - ) + ds_kw = { + "target_formatter": target_formatter, + "categorical_fields": categorical_fields, + "numerical_fields": numerical_fields, + "target_fields": target_fields, + "parameters": parameters, + } train_input = input_cls(RunningStage.TRAINING, train_dict, **ds_kw) ds_kw["parameters"] = train_input.parameters if train_input else parameters @@ -633,13 +633,13 @@ def from_lists( >>> del train_data >>> del predict_data """ - ds_kw = dict( - target_formatter=target_formatter, - categorical_fields=categorical_fields, - numerical_fields=numerical_fields, - target_fields=target_fields, - parameters=parameters, - ) + ds_kw = { + "target_formatter": target_formatter, + "categorical_fields": categorical_fields, + "numerical_fields": numerical_fields, + "target_fields": target_fields, + "parameters": parameters, + } train_input = input_cls(RunningStage.TRAINING, train_list, **ds_kw) ds_kw["parameters"] = train_input.parameters if train_input else parameters diff --git a/src/flash/tabular/input.py b/src/flash/tabular/input.py index 26e315c019..7210839342 100644 --- a/src/flash/tabular/input.py +++ b/src/flash/tabular/input.py @@ -64,13 +64,13 @@ def compute_parameters( codes = _generate_codes(train_data_frame, categorical_fields) - return dict( - mean=mean, - std=std, - codes=codes, - numerical_fields=numerical_fields, - categorical_fields=categorical_fields, - ) + return { + "mean": mean, + "std": std, + "codes": codes, + "numerical_fields": numerical_fields, + "categorical_fields": categorical_fields, + } def preprocess( self, diff --git a/src/flash/tabular/regression/data.py b/src/flash/tabular/regression/data.py index 1676f6aef6..9e1dd0d01c 100644 --- a/src/flash/tabular/regression/data.py +++ b/src/flash/tabular/regression/data.py @@ -147,12 +147,12 @@ def from_data_frame( >>> del train_data >>> del predict_data """ - ds_kw = dict( - categorical_fields=categorical_fields, - numerical_fields=numerical_fields, - target_field=target_field, - parameters=parameters, - ) + ds_kw = { + "categorical_fields": categorical_fields, + "numerical_fields": numerical_fields, + "target_field": target_field, + "parameters": parameters, + } train_input = input_cls(RunningStage.TRAINING, train_data_frame, **ds_kw) ds_kw["parameters"] = train_input.parameters if train_input else parameters @@ -334,12 +334,12 @@ def from_csv( >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") """ - ds_kw = dict( - categorical_fields=categorical_fields, - numerical_fields=numerical_fields, - target_field=target_field, - parameters=parameters, - ) + ds_kw = { + "categorical_fields": categorical_fields, + "numerical_fields": numerical_fields, + "target_field": target_field, + "parameters": parameters, + } train_input = input_cls(RunningStage.TRAINING, train_file, **ds_kw) ds_kw["parameters"] = train_input.parameters if train_input else parameters @@ -460,12 +460,12 @@ def from_dicts( >>> del train_data >>> del predict_data """ - ds_kw = dict( - categorical_fields=categorical_fields, - numerical_fields=numerical_fields, - target_field=target_field, - parameters=parameters, - ) + ds_kw = { + "categorical_fields": categorical_fields, + "numerical_fields": numerical_fields, + "target_field": target_field, + "parameters": parameters, + } train_input = input_cls(RunningStage.TRAINING, train_dict, **ds_kw) ds_kw["parameters"] = train_input.parameters if train_input else parameters @@ -589,12 +589,12 @@ def from_lists( >>> del train_data >>> del predict_data """ - ds_kw = dict( - categorical_fields=categorical_fields, - numerical_fields=numerical_fields, - target_field=target_field, - parameters=parameters, - ) + ds_kw = { + "categorical_fields": categorical_fields, + "numerical_fields": numerical_fields, + "target_field": target_field, + "parameters": parameters, + } train_input = input_cls(RunningStage.TRAINING, train_list, **ds_kw) ds_kw["parameters"] = train_input.parameters if train_input else parameters diff --git a/src/flash/template/classification/data.py b/src/flash/template/classification/data.py index 42ec03f98c..ce0318f758 100644 --- a/src/flash/template/classification/data.py +++ b/src/flash/template/classification/data.py @@ -157,7 +157,7 @@ def from_numpy( The constructed data module. """ - ds_kw = dict() + ds_kw = {} train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw) target_formatter = getattr(train_input, "target_formatter", None) @@ -201,7 +201,7 @@ def from_sklearn( Returns: The constructed data module. """ - ds_kw = dict() + ds_kw = {} train_input = input_cls(RunningStage.TRAINING, train_bunch, **ds_kw) target_formatter = getattr(train_input, "target_formatter", None) diff --git a/src/flash/text/classification/data.py b/src/flash/text/classification/data.py index 05e2a5c42b..a5ef47cdb3 100644 --- a/src/flash/text/classification/data.py +++ b/src/flash/text/classification/data.py @@ -211,11 +211,11 @@ def from_csv( >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") """ - ds_kw = dict( - target_formatter=target_formatter, - input_key=input_field, - target_keys=target_fields, - ) + ds_kw = { + "target_formatter": target_formatter, + "input_key": input_field, + "target_keys": target_fields, + } train_input = input_cls(RunningStage.TRAINING, train_file, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -333,12 +333,12 @@ def from_json( >>> os.remove("train_data.json") >>> os.remove("predict_data.json") """ - ds_kw = dict( - target_formatter=target_formatter, - input_key=input_field, - target_keys=target_fields, - field=field, - ) + ds_kw = { + "target_formatter": target_formatter, + "input_key": input_field, + "target_keys": target_fields, + "field": field, + } train_input = input_cls(RunningStage.TRAINING, train_file, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -456,11 +456,11 @@ def from_parquet( >>> os.remove("train_data.parquet") >>> os.remove("predict_data.parquet") """ - ds_kw = dict( - target_formatter=target_formatter, - input_key=input_field, - target_keys=target_fields, - ) + ds_kw = { + "target_formatter": target_formatter, + "input_key": input_field, + "target_keys": target_fields, + } train_input = input_cls(RunningStage.TRAINING, train_file, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -559,11 +559,11 @@ def from_hf_datasets( >>> del train_data >>> del predict_data """ - ds_kw = dict( - target_formatter=target_formatter, - input_key=input_field, - target_keys=target_fields, - ) + ds_kw = { + "target_formatter": target_formatter, + "input_key": input_field, + "target_keys": target_fields, + } train_input = input_cls(RunningStage.TRAINING, train_hf_dataset, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -663,11 +663,11 @@ def from_data_frame( >>> del train_data >>> del predict_data """ - ds_kw = dict( - target_formatter=target_formatter, - input_key=input_field, - target_keys=target_fields, - ) + ds_kw = { + "target_formatter": target_formatter, + "input_key": input_field, + "target_keys": target_fields, + } train_input = input_cls(RunningStage.TRAINING, train_data_frame, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -749,9 +749,9 @@ def from_lists( >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... """ - ds_kw = dict( - target_formatter=target_formatter, - ) + ds_kw = { + "target_formatter": target_formatter, + } train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) @@ -829,7 +829,7 @@ def from_labelstudio( multi_label=multi_label, ) - ds_kw = dict() + ds_kw = {} train_input = input_cls(RunningStage.TRAINING, train_data, **ds_kw) ds_kw["parameters"] = getattr(train_input, "parameters", None) diff --git a/src/flash/text/question_answering/data.py b/src/flash/text/question_answering/data.py index f87a41faf8..77c3af4d21 100644 --- a/src/flash/text/question_answering/data.py +++ b/src/flash/text/question_answering/data.py @@ -209,11 +209,11 @@ def from_csv( >>> os.remove("predict_data.tsv") """ - ds_kw = dict( - question_column_name=question_column_name, - context_column_name=context_column_name, - answer_column_name=answer_column_name, - ) + ds_kw = { + "question_column_name": question_column_name, + "context_column_name": context_column_name, + "answer_column_name": answer_column_name, + } return cls( input_cls(RunningStage.TRAINING, train_file, **ds_kw), @@ -353,12 +353,12 @@ def from_json( >>> os.remove("predict_data.json") """ - ds_kw = dict( - field=field, - question_column_name=question_column_name, - context_column_name=context_column_name, - answer_column_name=answer_column_name, - ) + ds_kw = { + "field": field, + "question_column_name": question_column_name, + "context_column_name": context_column_name, + "answer_column_name": answer_column_name, + } return cls( input_cls(RunningStage.TRAINING, train_file, **ds_kw), @@ -635,11 +635,11 @@ def from_squad_v2( >>> os.remove("predict_data.json") """ - ds_kw = dict( - question_column_name=question_column_name, - context_column_name=context_column_name, - answer_column_name=answer_column_name, - ) + ds_kw = { + "question_column_name": question_column_name, + "context_column_name": context_column_name, + "answer_column_name": answer_column_name, + } return cls( input_cls(RunningStage.TRAINING, train_file, **ds_kw), @@ -748,11 +748,11 @@ def from_dicts( >>> del predict_data """ - ds_kw = dict( - question_column_name=question_column_name, - context_column_name=context_column_name, - answer_column_name=answer_column_name, - ) + ds_kw = { + "question_column_name": question_column_name, + "context_column_name": context_column_name, + "answer_column_name": answer_column_name, + } return cls( input_cls(RunningStage.TRAINING, train_data, **ds_kw), diff --git a/src/flash/text/question_answering/input.py b/src/flash/text/question_answering/input.py index 8517ea9c14..b587a3346b 100644 --- a/src/flash/text/question_answering/input.py +++ b/src/flash/text/question_answering/input.py @@ -78,7 +78,7 @@ def load_data( if flash._IS_TESTING: # NOTE: must subset in this way to return a Dataset - hf_dataset = [sample for sample in hf_dataset][:40] + hf_dataset = list(hf_dataset)[:40] return hf_dataset @@ -166,7 +166,7 @@ def load_data( if not self.predicting: _answer_starts = [answer["answer_start"] for answer in qa["answers"]] _answers = [answer["text"] for answer in qa["answers"]] - answers.append(dict(text=_answers, answer_start=_answer_starts)) + answers.append({"text": _answers, "answer_start": _answer_starts}) data = {"id": ids, "title": titles, "context": contexts, "question": questions} if not self.predicting: diff --git a/src/flash/text/seq2seq/core/input.py b/src/flash/text/seq2seq/core/input.py index 857aa75bdb..e8d939a717 100644 --- a/src/flash/text/seq2seq/core/input.py +++ b/src/flash/text/seq2seq/core/input.py @@ -45,7 +45,7 @@ def load_data( if flash._IS_TESTING: # NOTE: must subset in this way to return a Dataset - hf_dataset = [sample for sample in hf_dataset][:40] + hf_dataset = list(hf_dataset)[:40] return hf_dataset diff --git a/src/flash/text/seq2seq/summarization/data.py b/src/flash/text/seq2seq/summarization/data.py index 8f7c2899d7..ca8c21eb19 100644 --- a/src/flash/text/seq2seq/summarization/data.py +++ b/src/flash/text/seq2seq/summarization/data.py @@ -186,10 +186,10 @@ def from_csv( >>> os.remove("predict_data.tsv") """ - ds_kw = dict( - input_key=input_field, - target_key=target_field, - ) + ds_kw = { + "input_key": input_field, + "target_key": target_field, + } return cls( input_cls(RunningStage.TRAINING, train_file, **ds_kw), @@ -295,11 +295,11 @@ def from_json( >>> os.remove("predict_data.json") """ - ds_kw = dict( - input_key=input_field, - target_key=target_field, - field=field, - ) + ds_kw = { + "input_key": input_field, + "target_key": target_field, + "field": field, + } return cls( input_cls(RunningStage.TRAINING, train_file, **ds_kw), @@ -388,10 +388,10 @@ def from_hf_datasets( >>> del predict_data """ - ds_kw = dict( - input_key=input_field, - target_key=target_field, - ) + ds_kw = { + "input_key": input_field, + "target_key": target_field, + } return cls( input_cls(RunningStage.TRAINING, train_hf_dataset, **ds_kw), @@ -462,7 +462,7 @@ def from_lists( Predicting... """ - ds_kw = dict() + ds_kw = {} return cls( input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw), diff --git a/src/flash/text/seq2seq/translation/data.py b/src/flash/text/seq2seq/translation/data.py index e5c8af1c0a..12dced0bec 100644 --- a/src/flash/text/seq2seq/translation/data.py +++ b/src/flash/text/seq2seq/translation/data.py @@ -184,10 +184,10 @@ def from_csv( >>> os.remove("predict_data.tsv") """ - ds_kw = dict( - input_key=input_field, - target_key=target_field, - ) + ds_kw = { + "input_key": input_field, + "target_key": target_field, + } return cls( input_cls(RunningStage.TRAINING, train_file, **ds_kw), @@ -292,11 +292,11 @@ def from_json( >>> os.remove("predict_data.json") """ - ds_kw = dict( - input_key=input_field, - target_key=target_field, - field=field, - ) + ds_kw = { + "input_key": input_field, + "target_key": target_field, + "field": field, + } return cls( input_cls(RunningStage.TRAINING, train_file, **ds_kw), @@ -385,10 +385,10 @@ def from_hf_datasets( >>> del predict_data """ - ds_kw = dict( - input_key=input_field, - target_key=target_field, - ) + ds_kw = { + "input_key": input_field, + "target_key": target_field, + } return cls( input_cls(RunningStage.TRAINING, train_hf_dataset, **ds_kw), @@ -459,7 +459,7 @@ def from_lists( Predicting... """ - ds_kw = dict() + ds_kw = {} return cls( input_cls(RunningStage.TRAINING, train_data, train_targets, **ds_kw), diff --git a/src/flash/text/seq2seq/translation/model.py b/src/flash/text/seq2seq/translation/model.py index 2be597e052..71fe3834aa 100644 --- a/src/flash/text/seq2seq/translation/model.py +++ b/src/flash/text/seq2seq/translation/model.py @@ -91,7 +91,7 @@ def compute_metrics(self, generated_tokens, batch, prefix): reference_corpus = [[reference] for reference in reference_corpus] translate_corpus = self.decode(generated_tokens) - translate_corpus = [line for line in translate_corpus] + translate_corpus = list(translate_corpus) result = self.bleu(translate_corpus, reference_corpus) self.log(f"{prefix}_bleu_score", result, on_step=False, on_epoch=True, prog_bar=True) diff --git a/src/flash/video/classification/data.py b/src/flash/video/classification/data.py index 15065c643b..531b2c6615 100644 --- a/src/flash/video/classification/data.py +++ b/src/flash/video/classification/data.py @@ -169,13 +169,13 @@ def from_files( >>> _ = [os.remove(f"video_{i}.mp4") for i in range(1, 4)] >>> _ = [os.remove(f"predict_video_{i}.mp4") for i in range(1, 4)] """ - ds_kw = dict( - clip_sampler=clip_sampler, - clip_duration=clip_duration, - clip_sampler_kwargs=clip_sampler_kwargs, - decode_audio=decode_audio, - decoder=decoder, - ) + ds_kw = { + "clip_sampler": clip_sampler, + "clip_duration": clip_duration, + "clip_sampler_kwargs": clip_sampler_kwargs, + "decode_audio": decode_audio, + "decoder": decoder, + } train_input = input_cls( RunningStage.TRAINING, @@ -334,13 +334,13 @@ def from_folders( >>> shutil.rmtree("train_folder") >>> shutil.rmtree("predict_folder") """ - ds_kw = dict( - clip_sampler=clip_sampler, - clip_duration=clip_duration, - clip_sampler_kwargs=clip_sampler_kwargs, - decode_audio=decode_audio, - decoder=decoder, - ) + ds_kw = { + "clip_sampler": clip_sampler, + "clip_duration": clip_duration, + "clip_sampler_kwargs": clip_sampler_kwargs, + "decode_audio": decode_audio, + "decoder": decoder, + } train_input = input_cls( RunningStage.TRAINING, @@ -517,13 +517,13 @@ def from_data_frame( >>> del train_data_frame >>> del predict_data_frame """ - ds_kw = dict( - clip_sampler=clip_sampler, - clip_duration=clip_duration, - clip_sampler_kwargs=clip_sampler_kwargs, - decode_audio=decode_audio, - decoder=decoder, - ) + ds_kw = { + "clip_sampler": clip_sampler, + "clip_duration": clip_duration, + "clip_sampler_kwargs": clip_sampler_kwargs, + "decode_audio": decode_audio, + "decoder": decoder, + } train_data = (train_data_frame, input_field, target_fields, train_videos_root, train_resolver) val_data = (val_data_frame, input_field, target_fields, val_videos_root, val_resolver) @@ -915,13 +915,13 @@ def from_csv( >>> os.remove("train_data.tsv") >>> os.remove("predict_data.tsv") """ - ds_kw = dict( - clip_sampler=clip_sampler, - clip_duration=clip_duration, - clip_sampler_kwargs=clip_sampler_kwargs, - decode_audio=decode_audio, - decoder=decoder, - ) + ds_kw = { + "clip_sampler": clip_sampler, + "clip_duration": clip_duration, + "clip_sampler_kwargs": clip_sampler_kwargs, + "decode_audio": decode_audio, + "decoder": decoder, + } train_data = (train_file, input_field, target_fields, train_videos_root, train_resolver) val_data = (val_file, input_field, target_fields, val_videos_root, val_resolver) @@ -1072,13 +1072,13 @@ def from_fiftyone( >>> del train_dataset >>> del predict_dataset """ - ds_kw = dict( - clip_sampler=clip_sampler, - clip_duration=clip_duration, - clip_sampler_kwargs=clip_sampler_kwargs, - decode_audio=decode_audio, - decoder=decoder, - ) + ds_kw = { + "clip_sampler": clip_sampler, + "clip_duration": clip_duration, + "clip_sampler_kwargs": clip_sampler_kwargs, + "decode_audio": decode_audio, + "decoder": decoder, + } train_input = input_cls( RunningStage.TRAINING, @@ -1198,14 +1198,14 @@ def from_labelstudio( multi_label=multi_label, ) - ds_kw = dict( - clip_sampler=clip_sampler, - clip_duration=clip_duration, - clip_sampler_kwargs=clip_sampler_kwargs, - video_sampler=video_sampler, - decode_audio=decode_audio, - decoder=decoder, - ) + ds_kw = { + "clip_sampler": clip_sampler, + "clip_duration": clip_duration, + "clip_sampler_kwargs": clip_sampler_kwargs, + "video_sampler": video_sampler, + "decode_audio": decode_audio, + "decoder": decoder, + } train_input = input_cls(RunningStage.TRAINING, train_data, **ds_kw) ds_kw["parameters"] = getattr(train_input, "parameters", None) diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py index 0212c40215..bebf8e18d4 100644 --- a/tests/audio/classification/test_data.py +++ b/tests/audio/classification/test_data.py @@ -66,7 +66,7 @@ def test_from_filepaths(tmpdir, file_generator): imgs, labels = data["input"], data["target"] assert imgs.shape == (2, channels, 128, 128) assert labels.shape == (2,) - assert sorted(list(labels.numpy())) == [1, 2] + assert sorted(labels.numpy()) == [1, 2] @pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") @@ -136,7 +136,7 @@ def test_from_filepaths_numpy(tmpdir): imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 128, 128) assert labels.shape == (2,) - assert sorted(list(labels.numpy())) == [1, 2] + assert sorted(labels.numpy()) == [1, 2] @pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") @@ -242,7 +242,7 @@ def test_from_filepaths_visualise_multilabel(tmpdir): test_files=[image_b, image_b], test_targets=[[0, 0, 1], [1, 1, 0]], batch_size=2, - transform_kwargs=dict(spectrogram_size=(64, 64)), + transform_kwargs={"spectrogram_size": (64, 64)}, ) # disable visualisation for testing assert dm.data_fetcher.block_viz_window is True diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py index 4751c7ad8e..29cc516bcf 100644 --- a/tests/audio/speech_recognition/test_model.py +++ b/tests/audio/speech_recognition/test_model.py @@ -29,7 +29,7 @@ class TestSpeechRecognition(TaskTester): task = SpeechRecognition - task_kwargs = dict(backbone=TEST_BACKBONE) + task_kwargs = {"backbone": TEST_BACKBONE} cli_command = "speech_recognition" is_testing = _TOPIC_AUDIO_AVAILABLE is_available = _TOPIC_AUDIO_AVAILABLE diff --git a/tests/core/serve/test_components.py b/tests/core/serve/test_components.py index c95652e37f..c77ea49457 100644 --- a/tests/core/serve/test_components.py +++ b/tests/core/serve/test_components.py @@ -29,7 +29,7 @@ def test_model_compute_dependencies(lightning_squeezenet1_1_obj): "target_key": "tag", } ] - assert list(map(lambda x: x._asdict(), comp1._flashserve_meta_.connections)) == res + assert [x._asdict() for x in comp1._flashserve_meta_.connections] == res assert list(comp2._flashserve_meta_.connections) == [] @@ -48,7 +48,7 @@ def test_inverse_model_compute_component_dependencies(lightning_squeezenet1_1_ob "target_key": "tag", } ] - assert list(map(lambda x: x._asdict(), comp2._flashserve_meta_.connections)) == res + assert [x._asdict() for x in comp2._flashserve_meta_.connections] == res assert list(comp1._flashserve_meta_.connections) == [] diff --git a/tests/core/serve/test_dag/test_optimization.py b/tests/core/serve/test_dag/test_optimization.py index 1c9712d63c..12b25d73f1 100644 --- a/tests/core/serve/test_dag/test_optimization.py +++ b/tests/core/serve/test_dag/test_optimization.py @@ -1116,7 +1116,7 @@ def test_SubgraphCallable(): f3 = SubgraphCallable(dsk, "g", ["in1", "in2"], name="test") assert f != f3 - assert dict(f=None) + assert {"f": None} assert hash(SubgraphCallable(None, None, [None])) assert hash(f3) != hash(f2) diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py index bc7e6e9b3b..54b5a7fd56 100644 --- a/tests/core/utilities/test_lightning_cli.py +++ b/tests/core/utilities/test_lightning_cli.py @@ -76,16 +76,16 @@ def test_add_argparse_args_redefined(cli_args): @pytest.mark.parametrize( ["cli_args", "expected"], [ - ("--auto_lr_find=True --auto_scale_batch_size=power", dict(auto_lr_find=True, auto_scale_batch_size="power")), + ("--auto_lr_find=True --auto_scale_batch_size=power", {"auto_lr_find": True, "auto_scale_batch_size": "power"}), ( "--auto_lr_find any_string --auto_scale_batch_size ON", - dict(auto_lr_find="any_string", auto_scale_batch_size=True), + {"auto_lr_find": "any_string", "auto_scale_batch_size": True}, ), - ("--auto_lr_find=Yes --auto_scale_batch_size=On", dict(auto_lr_find=True, auto_scale_batch_size=True)), - ("--auto_lr_find Off --auto_scale_batch_size No", dict(auto_lr_find=False, auto_scale_batch_size=False)), - ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", dict(auto_lr_find=True, auto_scale_batch_size=False)), - ("--limit_train_batches=100", dict(limit_train_batches=100)), - ("--limit_train_batches 0.8", dict(limit_train_batches=0.8)), + ("--auto_lr_find=Yes --auto_scale_batch_size=On", {"auto_lr_find": True, "auto_scale_batch_size": True}), + ("--auto_lr_find Off --auto_scale_batch_size No", {"auto_lr_find": False, "auto_scale_batch_size": False}), + ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", {"auto_lr_find": True, "auto_scale_batch_size": False}), + ("--limit_train_batches=100", {"limit_train_batches": 100}), + ("--limit_train_batches 0.8", {"limit_train_batches": 0.8}), ], ) def test_parse_args_parsing(cli_args, expected): @@ -105,9 +105,9 @@ def test_parse_args_parsing(cli_args, expected): @pytest.mark.parametrize( ["cli_args", "expected", "instantiate"], [ - (["--gpus", "[0, 2]"], dict(gpus=[0, 2]), False), - (["--tpu_cores=[1,3]"], dict(tpu_cores=[1, 3]), False), - (['--accumulate_grad_batches={"5":3,"10":20}'], dict(accumulate_grad_batches={5: 3, 10: 20}), True), + (["--gpus", "[0, 2]"], {"gpus": [0, 2]}, False), + (["--tpu_cores=[1,3]"], {"tpu_cores": [1, 3]}, False), + (['--accumulate_grad_batches={"5":3,"10":20}'], {"accumulate_grad_batches": {5: 3, 10: 20}}, True), ], ) def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): @@ -151,13 +151,13 @@ def test_parse_args_parsing_gpus(mocker, cli_args, expected_gpu): ["cli_args", "extra_args"], [ ({}, {}), - (dict(logger=False), {}), - (dict(logger=False), dict(logger=True)), - (dict(logger=False), dict(enable_checkpointing=True)), + ({"logger": False}, {}), + ({"logger": False}, {"logger": True}), + ({"logger": False}, {"enable_checkpointing": True}), ], ) def test_init_from_argparse_args(cli_args, extra_args): - unknown_args = dict(unknown_arg=0) + unknown_args = {"unknown_arg": 0} # unkown args in the argparser/namespace should be ignored with patch("pytorch_lightning.Trainer.__init__", autospec=True, return_value=None) as init: @@ -192,8 +192,8 @@ def trainer_builder( def test_lightning_cli(trainer_class, model_class, monkeypatch): """Test that LightningCLI correctly instantiates model, trainer and calls fit.""" - expected_model = dict(model_param=7) - expected_trainer = dict(limit_train_batches=100) + expected_model = {"model_param": 7} + expected_trainer = {"limit_train_batches": 100} def fit(trainer, model): for k, v in expected_model.items(): @@ -236,15 +236,15 @@ def on_fit_start(self): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_args_callbacks(tmpdir): callbacks = [ - dict( - class_path="pytorch_lightning.callbacks.LearningRateMonitor", - init_args=dict(logging_interval="epoch", log_momentum=True), - ), - dict(class_path="pytorch_lightning.callbacks.ModelCheckpoint", init_args=dict(monitor="NAME")), + { + "class_path": "pytorch_lightning.callbacks.LearningRateMonitor", + "init_args": {"logging_interval": "epoch", "log_momentum": True}, + }, + {"class_path": "pytorch_lightning.callbacks.ModelCheckpoint", "init_args": {"monitor": "NAME"}}, ] with patch("sys.argv", ["any.py", f"--trainer.callbacks={json.dumps(callbacks)}"]): - cli = LightningCLI(TestModelCallbacks, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) + cli = LightningCLI(TestModelCallbacks, trainer_defaults={"default_root_dir": str(tmpdir), "fast_dev_run": True}) assert cli.trainer.ran_asserts @@ -279,10 +279,12 @@ def on_fit_start(self): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.xfail(reason="Bugs in PL >= 1.6.0") def test_lightning_cli_args_cluster_environments(tmpdir): - plugins = [dict(class_path="pytorch_lightning.plugins.environments.SLURMEnvironment")] + plugins = [{"class_path": "pytorch_lightning.plugins.environments.SLURMEnvironment"}] with patch("sys.argv", ["any.py", f"--trainer.plugins={json.dumps(plugins)}"]): - cli = LightningCLI(TestModelClusterEnv, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True)) + cli = LightningCLI( + TestModelClusterEnv, trainer_defaults={"default_root_dir": str(tmpdir), "fast_dev_run": True} + ) assert cli.trainer.ran_asserts @@ -337,11 +339,11 @@ def test_lightning_cli_save_config_cases(tmpdir): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") def test_lightning_cli_config_and_subclass_mode(tmpdir): - config = dict( - model=dict(class_path="tests.helpers.boring_model.BoringModel"), - data=dict(class_path="tests.helpers.boring_model.BoringDataModule", init_args=dict(data_dir=str(tmpdir))), - trainer=dict(default_root_dir=str(tmpdir), max_epochs=1, enable_model_summary=False), - ) + config = { + "model": {"class_path": "tests.helpers.boring_model.BoringModel"}, + "data": {"class_path": "tests.helpers.boring_model.BoringDataModule", "init_args": {"data_dir": str(tmpdir)}}, + "trainer": {"default_root_dir": str(tmpdir), "max_epochs": 1, "enable_model_summary": False}, + } config_path = tmpdir / "config.yaml" with open(config_path, "w") as f: f.write(yaml.dump(config)) @@ -579,8 +581,8 @@ def on_exception(self, execption): @pytest.mark.parametrize( "trainer_kwargs", ( - dict(accelerator="cpu", strategy="ddp"), - dict(accelerator="cpu", strategy="ddp", plugins="ddp_find_unused_parameters_false"), + {"accelerator": "cpu", "strategy": "ddp"}, + {"accelerator": "cpu", "strategy": "ddp", "plugins": "ddp_find_unused_parameters_false"}, ), ) @pytest.mark.xfail(reason="Bugs in PL >= 1.6.0") @@ -673,14 +675,14 @@ def add_arguments_to_parser(self, parser): parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam)) parser.add_lr_scheduler_args((torch.optim.lr_scheduler.StepLR, torch.optim.lr_scheduler.ExponentialLR)) - optimizer_arg = dict( - class_path="torch.optim.Adam", - init_args=dict(lr=0.01), - ) - lr_scheduler_arg = dict( - class_path="torch.optim.lr_scheduler.StepLR", - init_args=dict(step_size=50), - ) + optimizer_arg = { + "class_path": "torch.optim.Adam", + "init_args": {"lr": 0.01}, + } + lr_scheduler_arg = { + "class_path": "torch.optim.lr_scheduler.StepLR", + "init_args": {"step_size": 50}, + } cli_args = [ f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", diff --git a/tests/image/classification/test_active_learning.py b/tests/image/classification/test_active_learning.py index 74cd0603be..3b8f3bd49a 100644 --- a/tests/image/classification/test_active_learning.py +++ b/tests/image/classification/test_active_learning.py @@ -56,7 +56,7 @@ def simple_datamodule(tmpdir): test_targets=[0] * n, batch_size=2, num_workers=0, - transform_kwargs=dict(image_size=image_size), + transform_kwargs={"image_size": image_size}, ) return dm diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index bdaf2aa987..87e5ecae7b 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -75,7 +75,7 @@ def test_from_filepaths_smoke(tmpdir): imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2,) - assert sorted(list(labels.numpy())) == [1, 2] + assert sorted(labels.numpy()) == [1, 2] @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @@ -111,19 +111,19 @@ def test_from_data_frame_smoke(tmpdir): imgs, labels = data["input"], data["target"] assert imgs.shape == (1, 3, 196, 196) assert labels.shape == (1,) - assert sorted(list(labels.numpy())) == [0] + assert sorted(labels.numpy()) == [0] data = next(iter(img_data.val_dataloader())) imgs, labels = data["input"], data["target"] assert imgs.shape == (1, 3, 196, 196) assert labels.shape == (1,) - assert sorted(list(labels.numpy())) == [1] + assert sorted(labels.numpy()) == [1] data = next(iter(img_data.test_dataloader())) imgs, labels = data["input"], data["target"] assert imgs.shape == (1, 3, 196, 196) assert labels.shape == (1,) - assert sorted(list(labels.numpy())) == [1] + assert sorted(labels.numpy()) == [1] data = next(iter(img_data.predict_dataloader())) imgs = data["input"] @@ -500,21 +500,21 @@ def test_from_fiftyone(tmpdir): imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2,) - assert sorted(list(labels.numpy())) == [0, 1] + assert sorted(labels.numpy()) == [0, 1] # check val data data = next(iter(img_data.val_dataloader())) imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2,) - assert sorted(list(labels.numpy())) == [0, 1] + assert sorted(labels.numpy()) == [0, 1] # check test data data = next(iter(img_data.test_dataloader())) imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2,) - assert sorted(list(labels.numpy())) == [0, 1] + assert sorted(labels.numpy()) == [0, 1] @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") diff --git a/tests/image/classification/test_training_strategies.py b/tests/image/classification/test_training_strategies.py index cd66fb2673..0236188768 100644 --- a/tests/image/classification/test_training_strategies.py +++ b/tests/image/classification/test_training_strategies.py @@ -78,7 +78,7 @@ def _test_learn2learning_training_strategies(gpus, training_strategy, tmpdir, ac train_targets=[0] * n + [1] * n + [2] * n + [3] * n, batch_size=1, num_workers=0, - transform_kwargs=dict(image_size=image_size), + transform_kwargs={"image_size": image_size}, ) model = ImageClassifier( diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py index d5c04a2972..7bc24c4e31 100644 --- a/tests/image/detection/test_data.py +++ b/tests/image/detection/test_data.py @@ -165,7 +165,7 @@ def test_image_detector_data_from_coco(tmpdir): train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) datamodule = ObjectDetectionData.from_coco( - train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1, transform_kwargs=dict(image_size=128) + train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1, transform_kwargs={"image_size": 128} ) data = next(iter(datamodule.train_dataloader())) @@ -181,7 +181,7 @@ def test_image_detector_data_from_coco(tmpdir): test_ann_file=coco_ann_path, batch_size=1, num_workers=0, - transform_kwargs=dict(image_size=128), + transform_kwargs={"image_size": 128}, ) data = next(iter(datamodule.val_dataloader())) @@ -198,7 +198,7 @@ def test_image_detector_data_from_fiftyone(tmpdir): train_dataset = _create_synth_fiftyone_dataset(tmpdir) datamodule = ObjectDetectionData.from_fiftyone( - train_dataset=train_dataset, batch_size=1, transform_kwargs=dict(image_size=128) + train_dataset=train_dataset, batch_size=1, transform_kwargs={"image_size": 128} ) data = next(iter(datamodule.train_dataloader())) @@ -211,7 +211,7 @@ def test_image_detector_data_from_fiftyone(tmpdir): test_dataset=train_dataset, batch_size=1, num_workers=0, - transform_kwargs=dict(image_size=128), + transform_kwargs={"image_size": 128}, ) data = next(iter(datamodule.val_dataloader())) @@ -227,7 +227,7 @@ def test_image_detector_data_from_fiftyone(tmpdir): def test_image_detector_data_from_files(tmpdir): predict_files = _create_synth_files_dataset(tmpdir) datamodule = ObjectDetectionData.from_files( - predict_files=predict_files, batch_size=1, transform_kwargs=dict(image_size=128) + predict_files=predict_files, batch_size=1, transform_kwargs={"image_size": 128} ) data = next(iter(datamodule.predict_dataloader())) sample = data[0] @@ -238,7 +238,7 @@ def test_image_detector_data_from_files(tmpdir): def test_image_detector_data_from_folders(tmpdir): predict_folder = _create_synth_folders_dataset(tmpdir) datamodule = ObjectDetectionData.from_folders( - predict_folder=predict_folder, batch_size=1, transform_kwargs=dict(image_size=128) + predict_folder=predict_folder, batch_size=1, transform_kwargs={"image_size": 128} ) data = next(iter(datamodule.predict_dataloader())) sample = data[0] diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py index 26e77dfa84..671a1ea330 100644 --- a/tests/image/embedding/test_model.py +++ b/tests/image/embedding/test_model.py @@ -30,9 +30,9 @@ class TestImageEmbedder(TaskTester): task = ImageEmbedder - task_kwargs = dict( - backbone="resnet18", - ) + task_kwargs = { + "backbone": "resnet18", + } is_testing = _TOPIC_IMAGE_AVAILABLE is_available = _TOPIC_IMAGE_AVAILABLE @@ -119,7 +119,7 @@ def test_only_embedding(backbone, embedding_size): datamodule = ImageClassificationData.from_datasets( predict_dataset=FakeData(8), batch_size=4, - transform_kwargs=dict(image_size=(224, 224)), + transform_kwargs={"image_size": (224, 224)}, ) embedder = ImageEmbedder(backbone=backbone) diff --git a/tests/image/instance_segm/test_data.py b/tests/image/instance_segm/test_data.py index 1d00ce8e7e..fe00336ba3 100644 --- a/tests/image/instance_segm/test_data.py +++ b/tests/image/instance_segm/test_data.py @@ -26,7 +26,7 @@ def test_image_detector_data_from_files(tmpdir): predict_files = _create_synth_files_dataset(tmpdir) datamodule = InstanceSegmentationData.from_files( - predict_files=predict_files, batch_size=2, transform_kwargs=dict(image_size=(128, 128)) + predict_files=predict_files, batch_size=2, transform_kwargs={"image_size": (128, 128)} ) data = next(iter(datamodule.predict_dataloader())) sample = data[0] @@ -37,7 +37,7 @@ def test_image_detector_data_from_files(tmpdir): def test_image_detector_data_from_folders(tmpdir): predict_folder = _create_synth_folders_dataset(tmpdir) datamodule = InstanceSegmentationData.from_folders( - predict_folder=predict_folder, batch_size=2, transform_kwargs=dict(image_size=(128, 128)) + predict_folder=predict_folder, batch_size=2, transform_kwargs={"image_size": (128, 128)} ) data = next(iter(datamodule.predict_dataloader())) sample = data[0] diff --git a/tests/image/instance_segm/test_model.py b/tests/image/instance_segm/test_model.py index 872a5e0e4d..e4637cb578 100644 --- a/tests/image/instance_segm/test_model.py +++ b/tests/image/instance_segm/test_model.py @@ -140,7 +140,7 @@ def test_model(coco_instances, backbone, head): train_folder=coco_instances.train_folder, train_ann_file=coco_instances.train_ann_file, predict_folder=coco_instances.predict_folder, - transform_kwargs=dict(image_size=(128, 128)), + transform_kwargs={"image_size": (128, 128)}, batch_size=2, ) diff --git a/tests/image/keypoint_detection/test_data.py b/tests/image/keypoint_detection/test_data.py index 7aefde8844..5de9b7b266 100644 --- a/tests/image/keypoint_detection/test_data.py +++ b/tests/image/keypoint_detection/test_data.py @@ -23,7 +23,7 @@ def test_image_detector_data_from_files(tmpdir): predict_files = _create_synth_files_dataset(tmpdir) datamodule = KeypointDetectionData.from_files( - predict_files=predict_files, batch_size=1, transform_kwargs=dict(image_size=(128, 128)) + predict_files=predict_files, batch_size=1, transform_kwargs={"image_size": (128, 128)} ) data = next(iter(datamodule.predict_dataloader())) sample = data[0] @@ -34,7 +34,7 @@ def test_image_detector_data_from_files(tmpdir): def test_image_detector_data_from_folders(tmpdir): predict_folder = _create_synth_folders_dataset(tmpdir) datamodule = KeypointDetectionData.from_folders( - predict_folder=predict_folder, batch_size=1, transform_kwargs=dict(image_size=(128, 128)) + predict_folder=predict_folder, batch_size=1, transform_kwargs={"image_size": (128, 128)} ) data = next(iter(datamodule.predict_dataloader())) sample = data[0] diff --git a/tests/image/keypoint_detection/test_model.py b/tests/image/keypoint_detection/test_model.py index be47ad5204..8c18a45826 100644 --- a/tests/image/keypoint_detection/test_model.py +++ b/tests/image/keypoint_detection/test_model.py @@ -147,7 +147,7 @@ def test_model(coco_keypoints, backbone, head): train_folder=coco_keypoints.train_folder, train_ann_file=coco_keypoints.train_ann_file, predict_folder=coco_keypoints.predict_folder, - transform_kwargs=dict(image_size=(128, 128)), + transform_kwargs={"image_size": (128, 128)}, batch_size=2, ) diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py index 073c97205b..e19dcd2940 100644 --- a/tests/text/classification/test_data.py +++ b/tests/text/classification/test_data.py @@ -163,15 +163,15 @@ def test_from_csv_multilabel(tmpdir): assert dm.multi_label batch = next(iter(dm.train_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) @@ -225,15 +225,15 @@ def test_from_json_multilabel(tmpdir): assert dm.multi_label batch = next(iter(dm.train_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) @@ -289,15 +289,15 @@ def test_from_json_with_field_multilabel(tmpdir): assert dm.multi_label batch = next(iter(dm.train_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) @@ -351,15 +351,15 @@ def test_from_parquet_multilabel(tmpdir): assert dm.multi_label batch = next(iter(dm.train_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) @@ -411,15 +411,15 @@ def test_from_data_frame_multilabel(): assert dm.multi_label batch = next(iter(dm.train_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) @@ -473,15 +473,15 @@ def test_from_hf_datasets_multilabel(): assert dm.multi_label batch = next(iter(dm.train_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) @@ -535,15 +535,15 @@ def test_from_lists_multilabel(): assert dm.multi_label batch = next(iter(dm.train_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) - assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) + assert all(label in [0, 1] for label in batch[DataKeys.TARGET][0]) assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) From 5009de45c3e3ef3c5015b76038c0bb6a967c6a35 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 11 May 2023 19:25:53 +0200 Subject: [PATCH 35/39] fix codecov --- .codecov.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index 646217fb2d..2f1fb68728 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -21,7 +21,6 @@ coverage: project: default: informational: true - against: auto target: 99% # specify the target coverage for each commit status threshold: 30% # allow this little decrease on project # https://github.com/codecov/support/wiki/Filtering-Branches @@ -31,7 +30,6 @@ coverage: patch: default: informational: true - against: auto target: 50% # specify the target "X%" coverage to hit # threshold: 50% # allow this much decrease on patch changes: false From de70b84ea5d5f207056a67fa243c95b708dfc5b7 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 11 May 2023 20:24:18 +0200 Subject: [PATCH 36/39] ruff: enable RET (#1554) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyproject.toml | 1 + .../speech_recognition/output_transform.py | 3 +- src/flash/core/data/data_module.py | 2 +- src/flash/core/data/io/input.py | 2 +- src/flash/core/data/io/input_transform.py | 6 +-- src/flash/core/data/transforms.py | 6 +-- .../core/data/utilities/classification.py | 13 +++--- .../core/integrations/labelstudio/input.py | 6 +-- .../integrations/labelstudio/visualizer.py | 43 +++++++++---------- .../integrations/pytorch_tabular/adapter.py | 4 +- .../integrations/pytorch_tabular/backbones.py | 3 +- src/flash/core/model.py | 4 +- src/flash/core/registry.py | 2 + src/flash/core/serve/dag/optimization.py | 4 +- src/flash/core/serve/dag/visualize.py | 2 +- src/flash/core/serve/execution.py | 3 +- src/flash/core/serve/interfaces/http.py | 3 +- src/flash/core/serve/server.py | 2 +- src/flash/core/utilities/apply_func.py | 1 + .../integrations/learn2learn.py | 3 +- src/flash/image/detection/input.py | 3 +- .../embedding/transforms/vissl_transforms.py | 8 +--- src/flash/image/embedding/vissl/adapter.py | 8 +--- .../pointcloud/detection/open3d_ml/input.py | 1 + .../open3d_ml/sequences_dataset.py | 14 +++--- src/flash/tabular/classification/input.py | 5 ++- src/flash/tabular/classification/model.py | 3 +- src/flash/tabular/classification/utils.py | 4 +- src/flash/tabular/regression/input.py | 5 ++- src/flash/tabular/regression/model.py | 3 +- src/flash/video/classification/input.py | 22 +++++----- src/flash/video/classification/utils.py | 4 +- tests/helpers/boring_model.py | 3 +- .../classification/test_active_learning.py | 3 +- tests/image/detection/test_data.py | 4 +- 35 files changed, 86 insertions(+), 117 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 79a4f52855..0c6b5f6da1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ select = [ ] extend-select = [ "C4", # see: https://pypi.org/project/flake8-comprehensions + "RET", # see: https://pypi.org/project/flake8-return ] ignore = [ "E731", # Do not assign a lambda expression, use a def diff --git a/src/flash/audio/speech_recognition/output_transform.py b/src/flash/audio/speech_recognition/output_transform.py index 1cd1314106..4a7a3319e4 100644 --- a/src/flash/audio/speech_recognition/output_transform.py +++ b/src/flash/audio/speech_recognition/output_transform.py @@ -33,8 +33,7 @@ def __init__(self, backbone: str): def per_batch_transform(self, batch: Any) -> Any: # converts logits into greedy transcription pred_ids = torch.argmax(batch, dim=-1) - transcriptions = self._tokenizer.batch_decode(pred_ids) - return transcriptions + return self._tokenizer.batch_decode(pred_ids) def __getstate__(self): # TODO: Find out why this is being pickled state = self.__dict__.copy() diff --git a/src/flash/core/data/data_module.py b/src/flash/core/data/data_module.py index 8fb75e5442..99f805b58d 100644 --- a/src/flash/core/data/data_module.py +++ b/src/flash/core/data/data_module.py @@ -462,7 +462,7 @@ def _show_batch( """This function is used to handle transforms profiling for batch visualization.""" # don't show in CI if os.getenv("FLASH_TESTING", "0") == "1": - return None + return iter_name = f"_{stage}_iter" if not hasattr(self, iter_name): diff --git a/src/flash/core/data/io/input.py b/src/flash/core/data/io/input.py index 6e1491d44e..7e42c448f2 100644 --- a/src/flash/core/data/io/input.py +++ b/src/flash/core/data/io/input.py @@ -111,7 +111,7 @@ def _validate_input(input: "InputBase") -> None: if input.data is not None: if isinstance(input, Input) and not _has_len(input.data): raise RuntimeError("`Input.data` is not a sequence with a defined length. Use `IterableInput` instead.") - elif isinstance(input, IterableInput) and _has_len(input.data): + if isinstance(input, IterableInput) and _has_len(input.data): raise RuntimeError("`IterableInput.data` is a sequence with a defined length. Use `Input` instead.") diff --git a/src/flash/core/data/io/input_transform.py b/src/flash/core/data/io/input_transform.py index f78d1c1c1b..49d754a8df 100644 --- a/src/flash/core/data/io/input_transform.py +++ b/src/flash/core/data/io/input_transform.py @@ -840,14 +840,13 @@ def create_worker_input_transform_processor( worker_collate_fn, _ = __configure_worker_and_device_collate_fn( running_stage=running_stage, input_transform=input_transform ) - worker_input_transform_processor = _InputTransformProcessor( + return _InputTransformProcessor( input_transform, worker_collate_fn, input_transform._per_sample_transform, input_transform._per_batch_transform, running_stage, ) - return worker_input_transform_processor def create_device_input_transform_processor( @@ -858,7 +857,7 @@ def create_device_input_transform_processor( _, device_collate_fn = __configure_worker_and_device_collate_fn( running_stage=running_stage, input_transform=input_transform ) - device_input_transform_processor = _InputTransformProcessor( + return _InputTransformProcessor( input_transform, device_collate_fn, input_transform._per_sample_transform_on_device, @@ -867,4 +866,3 @@ def create_device_input_transform_processor( apply_per_sample_transform=device_collate_fn != input_transform._identity, on_device=True, ) - return device_input_transform_processor diff --git a/src/flash/core/data/transforms.py b/src/flash/core/data/transforms.py index b7696f6433..6aa428e8c0 100644 --- a/src/flash/core/data/transforms.py +++ b/src/flash/core/data/transforms.py @@ -53,9 +53,9 @@ def forward(self, x: Any) -> Any: x_ = self.transform(**x_) if isinstance(x, dict): x.update({self._mapping_rev.get(k, k): x_[k] for k in self._mapping_rev if k in x_}) - else: - x = x_["image"] - return x + return x + + return x_["image"] class ApplyToKeys(nn.Sequential): diff --git a/src/flash/core/data/utilities/classification.py b/src/flash/core/data/utilities/classification.py index 6ba85486a1..19a40e0449 100644 --- a/src/flash/core/data/utilities/classification.py +++ b/src/flash/core/data/utilities/classification.py @@ -356,11 +356,10 @@ def _get_target_formatter_type(target: Any) -> Type[TargetFormatter]: # TODO: This could be a dangerous assumption if people happen to have a label that contains a comma or space if "," in target: return CommaDelimitedMultiLabelTargetFormatter - elif " " in target: + if " " in target: return SpaceDelimitedTargetFormatter - else: - return SingleLabelTargetFormatter - elif _is_list_like(target): + return SingleLabelTargetFormatter + if _is_list_like(target): if isinstance(target[0], str): return MultiLabelTargetFormatter target = _as_list(target) @@ -369,7 +368,7 @@ def _get_target_formatter_type(target: Any) -> Type[TargetFormatter]: if sum(target) == 1: return SingleBinaryTargetFormatter return MultiBinaryTargetFormatter - elif any(isinstance(t, float) for t in target): + if any(isinstance(t, float) for t in target): return MultiSoftTargetFormatter return MultiNumericTargetFormatter return SingleNumericTargetFormatter @@ -393,9 +392,9 @@ def _resolve_target_formatter(a: Type[TargetFormatter], b: Type[TargetFormatter] """ if a is b: return a - elif a in _RESOLUTION_MAPPING and b in _RESOLUTION_MAPPING[a]: + if a in _RESOLUTION_MAPPING and b in _RESOLUTION_MAPPING[a]: return b - elif b in _RESOLUTION_MAPPING and a in _RESOLUTION_MAPPING[b]: + if b in _RESOLUTION_MAPPING and a in _RESOLUTION_MAPPING[b]: return a raise ValueError( "Found inconsistent target formats. All targets should be either: single values, lists of values, or " diff --git a/src/flash/core/integrations/labelstudio/input.py b/src/flash/core/integrations/labelstudio/input.py index 723bf857e8..a8956aed82 100644 --- a/src/flash/core/integrations/labelstudio/input.py +++ b/src/flash/core/integrations/labelstudio/input.py @@ -141,11 +141,10 @@ def load_sample(self, sample: Mapping[str, Any] = None) -> Any: label = _get_labels_from_sample(sample["label"], self.parameters.classes) # delete label from input data del sample["label"] - result = { + return { DataKeys.INPUT: sample, DataKeys.TARGET: label, } - return result @staticmethod def _split_train_test_data(data: Dict, multi_label: bool = False) -> List[Dict]: @@ -241,11 +240,10 @@ def load_sample(self, sample: Mapping[str, Any] = None) -> Any: p = sample["file_upload"] # loading image image = load_image(p) - result = { + return { DataKeys.INPUT: image, DataKeys.TARGET: _get_labels_from_sample(sample["label"], self.parameters.classes), } - return result class LabelStudioTextClassificationInput(LabelStudioInput): diff --git a/src/flash/core/integrations/labelstudio/visualizer.py b/src/flash/core/integrations/labelstudio/visualizer.py index 180a6d5d12..166ebc1305 100644 --- a/src/flash/core/integrations/labelstudio/visualizer.py +++ b/src/flash/core/integrations/labelstudio/visualizer.py @@ -47,26 +47,26 @@ def show_tasks(self, predictions, export_json=None): else: task["predictions"] = [temp] return _raw_data - else: - print("No export file provided, meta information is generated!") - final_results = [] - for res in results: - temp = { - "result": [res], - "id": meta["max_predictions_id"], - "model_version": "", - "score": 0.0, - "task": meta["max_predictions_id"], - } - task = { - "id": meta["max_predictions_id"], - "predictions": [temp], - "data": {data_type: ""}, - "project": 1, - } - meta["max_predictions_id"] = meta["max_predictions_id"] + 1 - final_results.append(task) - return final_results + + print("No export file provided, meta information is generated!") + final_results = [] + for res in results: + temp = { + "result": [res], + "id": meta["max_predictions_id"], + "model_version": "", + "score": 0.0, + "task": meta["max_predictions_id"], + } + task = { + "id": meta["max_predictions_id"], + "predictions": [temp], + "data": {data_type: ""}, + "project": 1, + } + meta["max_predictions_id"] = meta["max_predictions_id"] + 1 + final_results.append(task) + return final_results def _construct_result(self, pred): """Construction Label Studio result from data source and prediction values.""" @@ -79,7 +79,7 @@ def _construct_result(self, pred): data_type = list(self.parameters.data_types)[0] # get tag type, if len(tag_types) > 1 take first tag tag_type = list(self.parameters.tag_types)[0] - js = { + return { "result": [ { "id": "".join( @@ -93,7 +93,6 @@ def _construct_result(self, pred): } ] } - return js def launch_app(datamodule: DataModule) -> "App": diff --git a/src/flash/core/integrations/pytorch_tabular/adapter.py b/src/flash/core/integrations/pytorch_tabular/adapter.py index 21b95b6552..b9aca1f243 100644 --- a/src/flash/core/integrations/pytorch_tabular/adapter.py +++ b/src/flash/core/integrations/pytorch_tabular/adapter.py @@ -51,15 +51,13 @@ def from_task( "continuous_dim": num_features - len(categorical_fields), "output_dim": output_dim, } - adapter = cls( + return cls( task_type, task.backbones.get(backbone)( task_type=task_type, parameters=parameters, loss_fn=loss_fn, metrics=metrics, **backbone_kwargs ), ) - return adapter - def convert_batch(self, batch): new_batch = { "continuous": batch[DataKeys.INPUT][1], diff --git a/src/flash/core/integrations/pytorch_tabular/backbones.py b/src/flash/core/integrations/pytorch_tabular/backbones.py index 6f2ca58d24..395d1051e8 100644 --- a/src/flash/core/integrations/pytorch_tabular/backbones.py +++ b/src/flash/core/integrations/pytorch_tabular/backbones.py @@ -76,10 +76,9 @@ def load_pytorch_tabular( OmegaConf.create(parameters), OmegaConf.to_container(model_config), ) - model = model_callable( + return model_callable( config=config, custom_loss=loss_fn, custom_metrics=metrics, inferred_config=DictConfig(config) ) - return model for model_config_class, name in zip( [ diff --git a/src/flash/core/model.py b/src/flash/core/model.py index 5f9535c3e3..fd3e3dc86c 100644 --- a/src/flash/core/model.py +++ b/src/flash/core/model.py @@ -487,8 +487,7 @@ def _get_optimizer_class_from_registry(self, optimizer_key: str) -> Optimizer: f"\nUse `{self.__class__.__name__}.available_optimizers()` to list the available optimizers." f"\nList of available Optimizers: {self.available_optimizers()}." ) - optimizer_fn = self.optimizers_registry.get(optimizer_key.lower()) - return optimizer_fn + return self.optimizers_registry.get(optimizer_key.lower()) def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: """Implement how optimizer and optionally learning rate schedulers should be configured.""" @@ -809,6 +808,7 @@ def configure_callbacks(self): # used only for CI if flash._IS_TESTING and torch.cuda.is_available(): return [BenchmarkConvergenceCI()] + return None @requires("serve") def run_serve_sanity_check( diff --git a/src/flash/core/registry.py b/src/flash/core/registry.py index b7f921cc14..4b9fdb28da 100644 --- a/src/flash/core/registry.py +++ b/src/flash/core/registry.py @@ -151,6 +151,7 @@ def _find_matching_index(self, item: _REGISTERED_FUNCTION) -> Optional[int]: for idx, fn in enumerate(self.functions): if all(fn[k] == item[k] for k in ("fn", "name", "metadata")): return idx + return None def __call__( self, @@ -315,6 +316,7 @@ def _register_function( for registry in self.registries: if getattr(registry, "_register_function", None) is not None: return registry._register_function(fn, name=name, override=override, metadata=metadata) + return None def available_keys(self) -> List[str]: return list(itertools.chain.from_iterable(registry.available_keys() for registry in self.registries)) diff --git a/src/flash/core/serve/dag/optimization.py b/src/flash/core/serve/dag/optimization.py index 7b84d264d7..6ae9008a0b 100644 --- a/src/flash/core/serve/dag/optimization.py +++ b/src/flash/core/serve/dag/optimization.py @@ -386,6 +386,7 @@ def _enforce_max_key_limit(key_name): names.append(first_key[0]) concatenated_name = "-".join(names) return (_enforce_max_key_limit(concatenated_name),) + first_key[1:] + return None # PEP-484 compliant singleton constant @@ -885,8 +886,7 @@ def __repr__(self): def __eq__(self, other): is_key = self.outkey == other.outkey and set(self.inkeys) == set(other.inkeys) - is_eq = type(self) is type(other) and self.name == other.name and is_key - return is_eq + return type(self) is type(other) and self.name == other.name and is_key def __ne__(self, other): return not self.__eq__(other) diff --git a/src/flash/core/serve/dag/visualize.py b/src/flash/core/serve/dag/visualize.py index bc847d984a..f43f644988 100644 --- a/src/flash/core/serve/dag/visualize.py +++ b/src/flash/core/serve/dag/visualize.py @@ -68,6 +68,6 @@ def visualize( data = g.pipe(format=format) fhandle.seek(0) fhandle.write(data) - return + return None return g diff --git a/src/flash/core/serve/execution.py b/src/flash/core/serve/execution.py index 1039f29a30..e5ec129115 100644 --- a/src/flash/core/serve/execution.py +++ b/src/flash/core/serve/execution.py @@ -317,7 +317,7 @@ def build_composition( toposort_keys = toposort(inlined_culled_dsk) # construct results - res = TaskComposition( + return TaskComposition( dsk=inlined_culled_dsk, sortkeys=toposort_keys, get_keys=initial_task_dsk.output_keys, @@ -325,7 +325,6 @@ def build_composition( ep_dsk_output_keys=initial_task_dsk.result_dsk_map, pre_optimization_dsk=initial_task_dsk.merged_dsk, ) - return res def _verify_no_cycles(dsk: Dict[str, tuple], out_keys: List[str], endpoint_name: str): diff --git a/src/flash/core/serve/interfaces/http.py b/src/flash/core/serve/interfaces/http.py index 915b3d5f17..9fe1f36b67 100644 --- a/src/flash/core/serve/interfaces/http.py +++ b/src/flash/core/serve/interfaces/http.py @@ -94,8 +94,7 @@ def endpoint_visualization(request: Request): f.seek(0) raw = f.read() encoded = base64.b64encode(raw).decode("ascii") - res = templates.TemplateResponse("dag.html", {"request": request, "encoded_image": encoded}) - return res + return templates.TemplateResponse("dag.html", {"request": request, "encoded_image": encoded}) return endpoint_visualization diff --git a/src/flash/core/serve/server.py b/src/flash/core/serve/server.py index ced1cc5fc9..aeaf00c034 100644 --- a/src/flash/core/serve/server.py +++ b/src/flash/core/serve/server.py @@ -39,7 +39,7 @@ def serve(self, host: str = "127.0.0.1", port: int = 8000): port number to expose the running server on """ if FLASH_DISABLE_SERVE: - return + return None if not self.TESTING: # pragma: no cover app = self.http_app() diff --git a/src/flash/core/utilities/apply_func.py b/src/flash/core/utilities/apply_func.py index b7e5ff7c21..c01c91280f 100644 --- a/src/flash/core/utilities/apply_func.py +++ b/src/flash/core/utilities/apply_func.py @@ -29,3 +29,4 @@ def get_callable_dict(fn: Union[nn.Module, Callable, Mapping, Sequence]) -> Unio return {get_callable_name(f): f for f in fn} if callable(fn): return {get_callable_name(fn): fn} + return None diff --git a/src/flash/image/classification/integrations/learn2learn.py b/src/flash/image/classification/integrations/learn2learn.py index 8912a66a37..d15d6b39ad 100644 --- a/src/flash/image/classification/integrations/learn2learn.py +++ b/src/flash/image/classification/integrations/learn2learn.py @@ -140,6 +140,5 @@ def __next__(self): for _ in range(self.worker_world_size): task_descriptions.append(self.taskset.sample_task_description()) - data = self.taskset.get_task(task_descriptions[self.worker_rank]) self.counter += 1 - return data + return self.taskset.get_task(task_descriptions[self.worker_rank]) diff --git a/src/flash/image/detection/input.py b/src/flash/image/detection/input.py index 97daee9f16..c20156b604 100644 --- a/src/flash/image/detection/input.py +++ b/src/flash/image/detection/input.py @@ -219,8 +219,7 @@ def _reformat_bbox(xmin, ymin, box_w, box_h, img_w, img_h): box_h *= img_h xmax = xmin + box_w ymax = ymin + box_h - output_bbox = [xmin, ymin, xmax, ymax] - return output_bbox + return [xmin, ymin, xmax, ymax] class ObjectDetectionFiftyOneInput(IceVisionInput): diff --git a/src/flash/image/embedding/transforms/vissl_transforms.py b/src/flash/image/embedding/transforms/vissl_transforms.py index 048d0ddfdb..bb60ec5c86 100644 --- a/src/flash/image/embedding/transforms/vissl_transforms.py +++ b/src/flash/image/embedding/transforms/vissl_transforms.py @@ -32,7 +32,7 @@ def simclr_transform( collate_fn: Callable = simclr_collate_fn, ) -> partial: """For simclr and barlow twins.""" - transform = partial( + return partial( StandardMultiCropSSLTransform, total_num_crops=total_num_crops, num_crops=num_crops, @@ -44,8 +44,6 @@ def simclr_transform( collate_fn=collate_fn, ) - return transform - def swav_transform( total_num_crops: int = 8, @@ -58,7 +56,7 @@ def swav_transform( collate_fn: Callable = multicrop_collate_fn, ) -> partial: """For swav.""" - transform = partial( + return partial( StandardMultiCropSSLTransform, total_num_crops=total_num_crops, num_crops=num_crops, @@ -70,8 +68,6 @@ def swav_transform( collate_fn=collate_fn, ) - return transform - barlow_twins_transform = partial(simclr_transform, collate_fn=simclr_collate_fn) diff --git a/src/flash/image/embedding/vissl/adapter.py b/src/flash/image/embedding/vissl/adapter.py index b059463430..96f192ab89 100644 --- a/src/flash/image/embedding/vissl/adapter.py +++ b/src/flash/image/embedding/vissl/adapter.py @@ -161,7 +161,7 @@ def on_epoch_start(self) -> None: @staticmethod def get_model_config_template(): - cfg = AttrDict( + return AttrDict( { "BASE_MODEL_NAME": "multi_input_output_model", "SINGLE_PASS_EVERY_CROP": False, @@ -192,8 +192,6 @@ def get_model_config_template(): } ) - return cfg - def ssl_forward(self, batch) -> Any: model_output = self.vissl_base_model(batch) @@ -211,9 +209,7 @@ def shared_step(self, batch: Any, train: bool = True) -> Any: for hook in self.hooks: hook.on_forward(self.vissl_task) - loss = self.loss_fn(out, target=None) - - return loss + return self.loss_fn(out, target=None) def training_step(self, batch: Any, batch_idx: int) -> Any: loss = self.shared_step(batch) diff --git a/src/flash/pointcloud/detection/open3d_ml/input.py b/src/flash/pointcloud/detection/open3d_ml/input.py index 30d9d317ff..50e8d77a78 100644 --- a/src/flash/pointcloud/detection/open3d_ml/input.py +++ b/src/flash/pointcloud/detection/open3d_ml/input.py @@ -153,6 +153,7 @@ def predict_load_data(self, data, dataset: Input): return self.load_files(data, dataset) if isinstance(data, str) and isdir(data): raise NotImplementedError + return None def predict_load_sample(self, metadata: Dict[str, str]): metadata, attr = self.load_sample(metadata, has_label=False) diff --git a/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py b/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py index 65af4e1315..6f7a4fcc53 100644 --- a/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py +++ b/src/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py @@ -147,10 +147,9 @@ def get_data(self, idx): points = DataProcessing.load_pc_kitti(pc_path) folder, file = split(pc_path) - if self.predicting: - label_path = join(folder, file[:-4] + ".label") - else: - label_path = join(folder, "../labels", file[:-4] + ".label") + label_path = ( + join(folder, file[:-4] + ".label") if self.predicting else join(folder, "../labels", file[:-4] + ".label") + ) if not exists(label_path): labels = np.zeros(np.shape(points)[0], dtype=np.int32) if self.split not in ["test", "all"]: @@ -159,14 +158,12 @@ def get_data(self, idx): else: labels = DataProcessing.load_label_kitti(label_path, self.remap_lut_val).astype(np.int32) - data = { + return { "point": points[:, 0:3], "feat": None, "label": labels, } - return data - def get_attr(self, idx): pc_path = self.path_list[idx] folder, file = split(pc_path) @@ -174,8 +171,7 @@ def get_attr(self, idx): name = f"{seq}_{file[:-4]}" pc_path = str(pc_path) - attr = {"idx": idx, "name": name, "path": pc_path, "split": self.split} - return attr + return {"idx": idx, "name": name, "path": pc_path, "split": self.split} def __len__(self): return len(self.path_list) diff --git a/src/flash/tabular/classification/input.py b/src/flash/tabular/classification/input.py index a393f2f2a9..84ccc40280 100644 --- a/src/flash/tabular/classification/input.py +++ b/src/flash/tabular/classification/input.py @@ -43,8 +43,8 @@ def load_data( targets = resolve_targets(data_frame, target_fields) self.load_target_metadata(targets, target_formatter=target_formatter) return [{DataKeys.INPUT: (c, n), DataKeys.TARGET: t} for c, n, t in zip(cat_vars, num_vars, targets)] - else: - return [{DataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)] + + return [{DataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)] def load_sample(self, sample: Dict[str, Any]) -> Any: if DataKeys.TARGET in sample: @@ -66,6 +66,7 @@ def load_data( return super().load_data( load_data_frame(file), categorical_fields, numerical_fields, target_fields, parameters, target_formatter ) + return None class TabularClassificationDictInput(TabularClassificationDataFrameInput): diff --git a/src/flash/tabular/classification/model.py b/src/flash/tabular/classification/model.py index 78c64fc970..15f83a6676 100644 --- a/src/flash/tabular/classification/model.py +++ b/src/flash/tabular/classification/model.py @@ -140,7 +140,7 @@ def _ci_benchmark_fn(history: List[Dict[str, Any]]): @classmethod def from_data(cls, datamodule, **kwargs) -> "TabularClassifier": - model = cls( + return cls( parameters=datamodule.parameters, embedding_sizes=datamodule.embedding_sizes, cat_dims=datamodule.cat_dims, @@ -148,7 +148,6 @@ def from_data(cls, datamodule, **kwargs) -> "TabularClassifier": num_classes=datamodule.num_classes, **kwargs, ) - return model @requires("serve") def serve( diff --git a/src/flash/tabular/classification/utils.py b/src/flash/tabular/classification/utils.py index e0fa50dade..9c208e9105 100644 --- a/src/flash/tabular/classification/utils.py +++ b/src/flash/tabular/classification/utils.py @@ -56,9 +56,7 @@ def _generate_codes(df: DataFrame, cat_cols: List) -> dict: tmp[col] = tmp[col].astype("category").cat.as_ordered() # list of categories for each column (always a column for None) - codes = {col: list(tmp[col].cat.categories) for col in cat_cols} - - return codes + return {col: list(tmp[col].cat.categories) for col in cat_cols} def _categorize(df: DataFrame, cat_cols: List, codes) -> DataFrame: diff --git a/src/flash/tabular/regression/input.py b/src/flash/tabular/regression/input.py index a673f19ee4..37091ffeaf 100644 --- a/src/flash/tabular/regression/input.py +++ b/src/flash/tabular/regression/input.py @@ -40,8 +40,8 @@ def load_data( if not self.predicting: targets = data_frame[target_field].to_numpy().astype(np.float32) return [{DataKeys.INPUT: (c, n), DataKeys.TARGET: t} for c, n, t in zip(cat_vars, num_vars, targets)] - else: - return [{DataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)] + + return [{DataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)] class TabularRegressionCSVInput(TabularRegressionDataFrameInput): @@ -57,6 +57,7 @@ def load_data( return super().load_data( load_data_frame(file), categorical_fields, numerical_fields, target_field, parameters ) + return None class TabularRegressionDictInput(TabularRegressionDataFrameInput): diff --git a/src/flash/tabular/regression/model.py b/src/flash/tabular/regression/model.py index 531cacd5d8..cbbac0f6da 100644 --- a/src/flash/tabular/regression/model.py +++ b/src/flash/tabular/regression/model.py @@ -130,14 +130,13 @@ def data_parameters(self) -> Dict[str, Any]: @classmethod def from_data(cls, datamodule, **kwargs) -> "TabularRegressor": - model = cls( + return cls( parameters=datamodule.parameters, embedding_sizes=datamodule.embedding_sizes, cat_dims=datamodule.cat_dims, num_features=datamodule.num_features, **kwargs ) - return model @requires("serve") def serve( diff --git a/src/flash/video/classification/input.py b/src/flash/video/classification/input.py index 5c2fb58d36..2d8b42ab54 100644 --- a/src/flash/video/classification/input.py +++ b/src/flash/video/classification/input.py @@ -393,18 +393,16 @@ class VideoClassificationTensorsPredictInput(Input): def predict_load_data(self, data: Union[torch.Tensor, List[Any], Any]): if _is_list_like(data): return data - else: - if not isinstance(data, torch.Tensor): - raise TypeError(f"Expected either a list/tuple of torch.Tensor or torch.Tensor, but got: {type(data)}.") - if data.ndim == 5: - return list(data) - elif data.ndim == 4: - return [data] - else: - raise ValueError( - f"Got dimension of the input tensor: {data.ndim}," - " for stack of tensors - dimension should be 5 or for a single tensor, dimension should be 4." - ) + if not isinstance(data, torch.Tensor): + raise TypeError(f"Expected either a list/tuple of torch.Tensor or torch.Tensor, but got: {type(data)}.") + if data.ndim == 5: + return list(data) + if data.ndim == 4: + return [data] + raise ValueError( + f"Got dimension of the input tensor: {data.ndim}," + " for stack of tensors - dimension should be 5 or for a single tensor, dimension should be 4." + ) def predict_load_sample(self, sample: torch.Tensor) -> Dict[str, Any]: return { diff --git a/src/flash/video/classification/utils.py b/src/flash/video/classification/utils.py index b585700ba1..1c8cd2526e 100644 --- a/src/flash/video/classification/utils.py +++ b/src/flash/video/classification/utils.py @@ -61,7 +61,7 @@ def __next__(self) -> dict: video_tensor, info_dict = self._labeled_videos[video_index] self._loaded_video_label = (video_tensor, info_dict, video_index) - sample_dict = { + return { "video": self._loaded_video_label[0], "video_name": f"video{video_index}", "video_index": video_index, @@ -69,8 +69,6 @@ def __next__(self) -> dict: "video_label": info_dict, } - return sample_dict - def __iter__(self): self._video_sampler_iter = None # Reset video sampler diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index d7f86a6da5..96505cde7a 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -46,8 +46,7 @@ def loss(self, batch, prediction): def step(self, x): x = self(x) - out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) - return out + return torch.nn.functional.mse_loss(x, torch.ones_like(x)) def training_step(self, batch, batch_idx): output = self(batch) diff --git a/tests/image/classification/test_active_learning.py b/tests/image/classification/test_active_learning.py index 3b8f3bd49a..4e1164788e 100644 --- a/tests/image/classification/test_active_learning.py +++ b/tests/image/classification/test_active_learning.py @@ -49,7 +49,7 @@ def simple_datamodule(tmpdir): _rand_image(image_size).save(pb_2) n = 10 - dm = ImageClassificationData.from_files( + return ImageClassificationData.from_files( train_files=[str(pa_1)] * n + [str(pa_2)] * n + [str(pb_1)] * n + [str(pb_2)] * n, train_targets=[0] * n + [1] * n + [2] * n + [3] * n, test_files=[str(pa_1)] * n, @@ -58,7 +58,6 @@ def simple_datamodule(tmpdir): num_workers=0, transform_kwargs={"image_size": image_size}, ) - return dm @pytest.mark.skipif( diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py index 7bc24c4e31..18703fe951 100644 --- a/tests/image/detection/test_data.py +++ b/tests/image/detection/test_data.py @@ -109,9 +109,7 @@ def _create_synth_folders_dataset(tmpdir): Image.new("RGB", (224, 224)).save(predict / "images" / "sample_one.png") Image.new("RGB", (224, 224)).save(predict / "images" / "sample_two.png") - predict_folder = os.fspath(Path(predict / "images")) - - return predict_folder + return os.fspath(Path(predict / "images")) def _create_synth_files_dataset(tmpdir): From a97ad5360111e5e2154770c3ca79db41ea0e7b1a Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 11 May 2023 21:06:10 +0200 Subject: [PATCH 37/39] ruff: enable SIM (#1555) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyproject.toml | 1 + src/flash/audio/classification/input.py | 5 +-- src/flash/core/data/data_module.py | 5 +-- src/flash/core/data/io/input_transform.py | 5 +-- src/flash/core/data/splits.py | 5 +-- .../integrations/pytorch_tabular/backbones.py | 2 +- .../core/integrations/transformers/collate.py | 2 +- src/flash/core/model.py | 34 ++++++++++--------- src/flash/core/optimizers/lamb.py | 6 ++-- src/flash/core/optimizers/lars.py | 16 ++++----- src/flash/core/registry.py | 10 ++---- src/flash/core/serve/component.py | 2 +- src/flash/core/serve/composition.py | 4 +-- src/flash/core/serve/core.py | 5 +-- src/flash/core/serve/dag/optimization.py | 5 +-- src/flash/core/serve/dag/order.py | 2 +- src/flash/core/serve/dag/utils.py | 5 +-- src/flash/core/serve/execution.py | 2 +- src/flash/core/utilities/flash_cli.py | 22 +++++++----- src/flash/graph/backbones.py | 2 +- .../image/classification/backbones/resnet.py | 2 +- src/flash/image/classification/data.py | 5 +-- src/flash/image/classification/input.py | 5 +-- .../classification/integrations/baal/loop.py | 5 ++- src/flash/image/detection/data.py | 5 +-- src/flash/image/detection/output.py | 5 +-- src/flash/image/embedding/model.py | 9 +++-- .../image/face_detection/input_transform.py | 2 +- .../image/keypoint_detection/backbones.py | 19 +++++------ src/flash/text/classification/adapters.py | 2 +- src/flash/text/question_answering/input.py | 15 ++++---- src/flash/video/classification/data.py | 5 +-- tests/audio/classification/test_model.py | 8 ++--- tests/core/serve/test_types/test_number.py | 2 +- tests/core/serve/test_types/test_text.py | 4 +-- tests/core/utilities/test_lightning_cli.py | 2 +- tests/helpers/task_tester.py | 8 ++--- .../test_training_strategies.py | 2 +- tests/image/face_detection/test_model.py | 8 ++--- tests/video/classification/test_model.py | 32 ++++++++--------- 40 files changed, 120 insertions(+), 165 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0c6b5f6da1..1639f58e29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ select = [ ] extend-select = [ "C4", # see: https://pypi.org/project/flake8-comprehensions + "SIM", # see: https://pypi.org/project/flake8-simplify "RET", # see: https://pypi.org/project/flake8-return ] ignore = [ diff --git a/src/flash/audio/classification/input.py b/src/flash/audio/classification/input.py index 812a924d4d..d0174fdb57 100644 --- a/src/flash/audio/classification/input.py +++ b/src/flash/audio/classification/input.py @@ -130,10 +130,7 @@ def load_data( target_formatter: Optional[TargetFormatter] = None, ) -> List[Dict[str, Any]]: files = resolve_files(data_frame, input_key, root, resolver) - if target_keys is not None: - targets = resolve_targets(data_frame, target_keys) - else: - targets = None + targets = resolve_targets(data_frame, target_keys) if target_keys is not None else None result = super().load_data( files, targets, sampling_rate=sampling_rate, n_fft=n_fft, target_formatter=target_formatter ) diff --git a/src/flash/core/data/data_module.py b/src/flash/core/data/data_module.py index 99f805b58d..849ff061c3 100644 --- a/src/flash/core/data/data_module.py +++ b/src/flash/core/data/data_module.py @@ -215,10 +215,7 @@ def _train_dataloader(self) -> DataLoader: input_transform = self._resolve_input_transform() shuffle: bool = False - if isinstance(train_ds, IterableDataset): - drop_last = False - else: - drop_last = len(train_ds) > self.batch_size + drop_last = False if isinstance(train_ds, IterableDataset) else len(train_ds) > self.batch_size if self.sampler is None: sampler = None diff --git a/src/flash/core/data/io/input_transform.py b/src/flash/core/data/io/input_transform.py index 49d754a8df..f65af3d78c 100644 --- a/src/flash/core/data/io/input_transform.py +++ b/src/flash/core/data/io/input_transform.py @@ -774,10 +774,7 @@ def __call__(self, samples: Sequence[Any]) -> Any: self.callback.on_load_sample(sample, self.stage) if self.apply_per_sample_transform: - if not isinstance(samples, list): - list_samples = [samples] - else: - list_samples = samples + list_samples = [samples] if not isinstance(samples, list) else samples transformed_samples = [self.per_sample_transform(sample, self.stage) for sample in list_samples] diff --git a/src/flash/core/data/splits.py b/src/flash/core/data/splits.py index 8862b70152..a51e29fade 100644 --- a/src/flash/core/data/splits.py +++ b/src/flash/core/data/splits.py @@ -40,10 +40,7 @@ def __init__( if not isinstance(indices, list): raise TypeError("indices should be a list") - if use_duplicated_indices: - indices = list(indices) - else: - indices = list(np.unique(indices)) + indices = list(indices) if use_duplicated_indices else list(np.unique(indices)) if np.max(indices) >= len(dataset) or np.min(indices) < 0: raise ValueError(f"`indices` should be within [0, {len(dataset) -1}].") diff --git a/src/flash/core/integrations/pytorch_tabular/backbones.py b/src/flash/core/integrations/pytorch_tabular/backbones.py index 395d1051e8..72084ae0d8 100644 --- a/src/flash/core/integrations/pytorch_tabular/backbones.py +++ b/src/flash/core/integrations/pytorch_tabular/backbones.py @@ -51,7 +51,7 @@ def _read_parse_config(config, cls): **{ k: v for k, v in _config.items() - if (k in cls.__dataclass_fields__.keys()) and cls.__dataclass_fields__[k].init + if (k in cls.__dataclass_fields__) and cls.__dataclass_fields__[k].init } ) else: diff --git a/src/flash/core/integrations/transformers/collate.py b/src/flash/core/integrations/transformers/collate.py index b78f2aa87a..43c764325e 100644 --- a/src/flash/core/integrations/transformers/collate.py +++ b/src/flash/core/integrations/transformers/collate.py @@ -46,4 +46,4 @@ def tokenize(self, sample): raise NotImplementedError def __call__(self, samples): - return self.to_tensor(self.tokenize({key: [sample[key] for sample in samples] for key in samples[0].keys()})) + return self.to_tensor(self.tokenize({key: [sample[key] for sample in samples] for key in samples[0]})) diff --git a/src/flash/core/model.py b/src/flash/core/model.py index fd3e3dc86c..3bdd199f3b 100644 --- a/src/flash/core/model.py +++ b/src/flash/core/model.py @@ -84,10 +84,9 @@ def __setattr__(self, key, value): if isinstance(value, (LightningModule, ModuleWrapperBase)): self._children.append(key) patched_attributes = ["_current_fx_name", "_current_hook_fx_name", "_results", "_data_pipeline_state"] - if isinstance(value, Trainer) or key in patched_attributes: - if hasattr(self, "_children"): - for child in self._children: - setattr(getattr(self, child), key, value) + if (isinstance(value, Trainer) or key in patched_attributes) and hasattr(self, "_children"): + for child in self._children: + setattr(getattr(self, child), key, value) super().__setattr__(key, value) @@ -773,17 +772,20 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: ) # Providers part - if lr_scheduler_metadata is not None and "providers" in lr_scheduler_metadata.keys(): - if lr_scheduler_metadata["providers"] == _HUGGINGFACE: - if lr_scheduler_data["name"] != "constant_schedule": - num_training_steps: int = self.get_num_training_steps() - num_warmup_steps: int = self._compute_warmup( - num_training_steps=num_training_steps, - num_warmup_steps=lr_scheduler_kwargs["num_warmup_steps"], - ) - lr_scheduler_kwargs["num_warmup_steps"] = num_warmup_steps - if lr_scheduler_data["name"] != "constant_schedule_with_warmup": - lr_scheduler_kwargs["num_training_steps"] = num_training_steps + if ( + lr_scheduler_metadata is not None + and "providers" in lr_scheduler_metadata + and lr_scheduler_metadata["providers"] == _HUGGINGFACE + and lr_scheduler_data["name"] != "constant_schedule" + ): + num_training_steps: int = self.get_num_training_steps() + num_warmup_steps: int = self._compute_warmup( + num_training_steps=num_training_steps, + num_warmup_steps=lr_scheduler_kwargs["num_warmup_steps"], + ) + lr_scheduler_kwargs["num_warmup_steps"] = num_warmup_steps + if lr_scheduler_data["name"] != "constant_schedule_with_warmup": + lr_scheduler_kwargs["num_training_steps"] = num_training_steps # User can register a callable that returns a lr_scheduler_config # 1) If return value is an instance of _LR_Scheduler -> Add to current config and return the config. @@ -792,7 +794,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: if isinstance(lr_scheduler, Dict): dummy_config = default_scheduler_config - if not all(config_key in dummy_config.keys() for config_key in lr_scheduler.keys()): + if not all(config_key in dummy_config for config_key in lr_scheduler): raise ValueError( f"Please make sure that your custom configuration outputs either an LR Scheduler or a scheduler" f" configuration with keys belonging to {list(dummy_config.keys())}." diff --git a/src/flash/core/optimizers/lamb.py b/src/flash/core/optimizers/lamb.py index 853d3f8668..d0b07e2615 100644 --- a/src/flash/core/optimizers/lamb.py +++ b/src/flash/core/optimizers/lamb.py @@ -72,15 +72,15 @@ def __init__( exclude_from_layer_adaptation: bool = False, amsgrad: bool = False, ): - if not 0.0 <= lr: + if not lr >= 0.0: raise ValueError(f"Invalid learning rate: {lr}") - if not 0.0 <= eps: + if not eps >= 0.0: raise ValueError(f"Invalid epsilon value: {eps}") if not 0.0 <= betas[0] < 1.0: raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") if not 0.0 <= betas[1] < 1.0: raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") - if not 0.0 <= weight_decay: + if not weight_decay >= 0.0: raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = { "lr": lr, diff --git a/src/flash/core/optimizers/lars.py b/src/flash/core/optimizers/lars.py index e1f0e43659..f89f2cba0b 100644 --- a/src/flash/core/optimizers/lars.py +++ b/src/flash/core/optimizers/lars.py @@ -142,13 +142,12 @@ def step(self, closure=None): g_norm = torch.norm(p.grad.data) # lars scaling + weight decay part - if weight_decay != 0: - if p_norm != 0 and g_norm != 0: - lars_lr = p_norm / (g_norm + p_norm * weight_decay + self.eps) - lars_lr *= self.trust_coefficient + if weight_decay != 0 and p_norm != 0 and g_norm != 0: + lars_lr = p_norm / (g_norm + p_norm * weight_decay + self.eps) + lars_lr *= self.trust_coefficient - d_p = d_p.add(p, alpha=weight_decay) - d_p *= lars_lr + d_p = d_p.add(p, alpha=weight_decay) + d_p *= lars_lr # sgd part if momentum != 0: @@ -158,10 +157,7 @@ def step(self, closure=None): else: buf = param_state["momentum_buffer"] buf.mul_(momentum).add_(d_p, alpha=1 - dampening) - if nesterov: - d_p = d_p.add(buf, alpha=momentum) - else: - d_p = buf + d_p = d_p.add(buf, alpha=momentum) if nesterov else buf p.add_(d_p, alpha=-group["lr"]) diff --git a/src/flash/core/registry.py b/src/flash/core/registry.py index 4b9fdb28da..b968ee1934 100644 --- a/src/flash/core/registry.py +++ b/src/flash/core/registry.py @@ -65,10 +65,7 @@ def __add__(self, other): else: registries += [self] - if isinstance(other, ConcatRegistry): - registries = other.registries + tuple(registries) - else: - registries = [other] + registries + registries = other.registries + tuple(registries) if isinstance(other, ConcatRegistry) else [other] + registries return ConcatRegistry(*registries) @@ -122,10 +119,7 @@ def _register_function( raise TypeError(f"You can only register a callable, found: {fn}") if name is None: - if hasattr(fn, "func"): - name = fn.func.__name__ - else: - name = fn.__name__ + name = fn.func.__name__ if hasattr(fn, "func") else fn.__name__ if self._verbose: rank_zero_info(f"Registering: {fn.__name__} function with name: {name} and metadata: {metadata}") diff --git a/src/flash/core/serve/component.py b/src/flash/core/serve/component.py index d7c6a9ee92..c267ea49b8 100644 --- a/src/flash/core/serve/component.py +++ b/src/flash/core/serve/component.py @@ -97,7 +97,7 @@ def _validate_model_args( if not all(isinstance(x, _Servable_t) for x in args): raise TypeError(f"One of arg in args={args} is not type {_Servable_t}") elif isinstance(args, dict): - if not all(isinstance(x, str) for x in args.keys()): + if not all(isinstance(x, str) for x in args): raise TypeError(f"One of keys in args={args.keys()} is not type {str}") if not all(isinstance(x, _Servable_t) for x in args.values()): raise TypeError(f"One of values in args={args} is not type {_Servable_t}") diff --git a/src/flash/core/serve/composition.py b/src/flash/core/serve/composition.py index 213e0e18de..d627c02995 100644 --- a/src/flash/core/serve/composition.py +++ b/src/flash/core/serve/composition.py @@ -93,8 +93,8 @@ def __init__( if len(self._name_endpoints) == 0: comp = first(self.components.values()) # one element iterable ep_route = f"/{comp._flashserve_meta_.exposed.__name__}" - ep_inputs = {k: f"{comp.uid}.inputs.{k}" for k in asdict(comp.inputs).keys()} - ep_outputs = {k: f"{comp.uid}.outputs.{k}" for k in asdict(comp.outputs).keys()} + ep_inputs = {k: f"{comp.uid}.inputs.{k}" for k in asdict(comp.inputs)} + ep_outputs = {k: f"{comp.uid}.outputs.{k}" for k in asdict(comp.outputs)} ep = Endpoint(route=ep_route, inputs=ep_inputs, outputs=ep_outputs) self._name_endpoints[f"{comp._flashserve_meta_.exposed.__name__}_ENDPOINT"] = ep diff --git a/src/flash/core/serve/core.py b/src/flash/core/serve/core.py index 6ebe5bb415..c55be4641f 100644 --- a/src/flash/core/serve/core.py +++ b/src/flash/core/serve/core.py @@ -116,10 +116,7 @@ def __init__( raise parsed = [script_loader_cls, parse_obj_as(Union[HttpUrl, FilePath], loc)] - if isinstance(parsed[-1], Path): - f_path = loc - else: - f_path = download_file(loc, download_path=download_path) + f_path = loc if isinstance(parsed[-1], Path) else download_file(loc, download_path=download_path) if len(args) == 2 and args[0].__qualname__ != script_loader_cls.__qualname__: # if this is a class and path/url... diff --git a/src/flash/core/serve/dag/optimization.py b/src/flash/core/serve/dag/optimization.py index 6ae9008a0b..4ab3a07ef2 100644 --- a/src/flash/core/serve/dag/optimization.py +++ b/src/flash/core/serve/dag/optimization.py @@ -493,10 +493,7 @@ def fuse( key_renamer = rename_keys rename_keys = key_renamer is not None - if dependencies is None: - deps = {k: get_dependencies(dsk, k, as_list=True) for k in dsk} - else: - deps = dict(dependencies) + deps = {k: get_dependencies(dsk, k, as_list=True) for k in dsk} if dependencies is None else dict(dependencies) rdeps = {} for k, vals in deps.items(): diff --git a/src/flash/core/serve/dag/order.py b/src/flash/core/serve/dag/order.py index 6d4ae93688..fff934f783 100644 --- a/src/flash/core/serve/dag/order.py +++ b/src/flash/core/serve/dag/order.py @@ -366,7 +366,7 @@ def finish_now_key(x): if now_keys: # Run before `inner_stack` (change tactical goal!) inner_stacks_append(inner_stack) - if 1 < len(now_keys): + if len(now_keys) > 1: now_keys.sort(reverse=True) for key in now_keys: pool = dep_pools[key] diff --git a/src/flash/core/serve/dag/utils.py b/src/flash/core/serve/dag/utils.py index c7e8925a6d..e90699cbae 100644 --- a/src/flash/core/serve/dag/utils.py +++ b/src/flash/core/serve/dag/utils.py @@ -80,10 +80,7 @@ def key_split(s): s = s[0] try: words = s.split("-") - if not words[0][0].isalpha(): - result = words[0].strip("_'()\"") - else: - result = words[0] + result = words[0].strip("_'()\"") if not words[0][0].isalpha() else words[0] for word in words[1:]: if word.isalpha() and not (len(word) == 8 and hex_pattern.match(word) is not None): result += "-" + word diff --git a/src/flash/core/serve/execution.py b/src/flash/core/serve/execution.py index e5ec129115..3b555660e0 100644 --- a/src/flash/core/serve/execution.py +++ b/src/flash/core/serve/execution.py @@ -267,7 +267,7 @@ def build_composition( dsk_tgt_src_connections[target_dsk] = (identity, source_dsk) rewrite_ruleset = RuleSet() - for dsk_payload_target_serial in initial_task_dsk.payload_tasks_dsk.keys(): + for dsk_payload_target_serial in initial_task_dsk.payload_tasks_dsk: dsk_payload_target, _serial_ident = dsk_payload_target_serial.rsplit(".", maxsplit=1) if _serial_ident != "serial": raise RuntimeError( diff --git a/src/flash/core/utilities/flash_cli.py b/src/flash/core/utilities/flash_cli.py index 5ceb592a32..1de8f5f9df 100644 --- a/src/flash/core/utilities/flash_cli.py +++ b/src/flash/core/utilities/flash_cli.py @@ -192,14 +192,16 @@ def parse_arguments(self) -> None: def add_arguments_to_parser(self, parser) -> None: subcommands = parser.add_subcommands() - for function in vars(self.local_datamodule_class).keys(): + for function in vars(self.local_datamodule_class): if not function.startswith("from"): continue - if ( - hasattr(DataModule, function) and is_overridden(function, self.local_datamodule_class, DataModule) - ) or not hasattr(DataModule, function): - if getattr(self.local_datamodule_class, function, None) is not None: - self.add_subcommand_from_function(subcommands, getattr(self.local_datamodule_class, function)) + _data_overwritten = hasattr(DataModule, function) and is_overridden( + function, self.local_datamodule_class, DataModule + ) + if (_data_overwritten or not hasattr(DataModule, function)) and getattr( + self.local_datamodule_class, function, None + ) is not None: + self.add_subcommand_from_function(subcommands, getattr(self.local_datamodule_class, function)) for datamodule_builder in self.additional_datamodule_builders: self.add_subcommand_from_function(subcommands, datamodule_builder) @@ -243,9 +245,11 @@ def instantiate_classes(self) -> None: self.datamodule = self._subcommand_builders[sub_config](**self.config.get(sub_config)) for datamodule_attribute in self.datamodule_attributes: - if datamodule_attribute in self.config["model"]: - if getattr(self.datamodule, datamodule_attribute, None) is not None: - self.config["model"][datamodule_attribute] = getattr(self.datamodule, datamodule_attribute) + if ( + datamodule_attribute in self.config["model"] + and getattr(self.datamodule, datamodule_attribute, None) is not None + ): + self.config["model"][datamodule_attribute] = getattr(self.datamodule, datamodule_attribute) self.config_init = self.parser.instantiate_classes(self.config) self.model = self.config_init["model"] self.instantiate_trainer() diff --git a/src/flash/graph/backbones.py b/src/flash/graph/backbones.py index d2186463d3..f7f23361f7 100644 --- a/src/flash/graph/backbones.py +++ b/src/flash/graph/backbones.py @@ -37,5 +37,5 @@ def _load_graph_backbone( return model(in_channels, hidden_channels, num_layers) -for model_name in MODELS.keys(): +for model_name in MODELS: GRAPH_BACKBONES(name=model_name, providers=_PYTORCH_GEOMETRIC)(partial(_load_graph_backbone, model_name)) diff --git a/src/flash/image/classification/backbones/resnet.py b/src/flash/image/classification/backbones/resnet.py index a3e1a21f73..0a5e0c6edb 100644 --- a/src/flash/image/classification/backbones/resnet.py +++ b/src/flash/image/classification/backbones/resnet.py @@ -338,7 +338,7 @@ def _resnet( weights_paths[pretrained], map_location=torch.device("cpu") if device == -1 else torch.device(device) ) - if "classy_state_dict" in model_weights.keys(): + if "classy_state_dict" in model_weights: model_weights = model_weights["classy_state_dict"]["base_model"]["model"]["trunk"] model_weights = { key.replace("_feature_blocks.", "") if "_feature_blocks." in key else key: val diff --git a/src/flash/image/classification/data.py b/src/flash/image/classification/data.py index e57065fa4a..a4e54e6d18 100644 --- a/src/flash/image/classification/data.py +++ b/src/flash/image/classification/data.py @@ -46,10 +46,7 @@ ) from flash.image.classification.input_transform import ImageClassificationInputTransform -if _FIFTYONE_AVAILABLE: - SampleCollection = "fiftyone.core.collections.SampleCollection" -else: - SampleCollection = None +SampleCollection = "fiftyone.core.collections.SampleCollection" if _FIFTYONE_AVAILABLE else None if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt diff --git a/src/flash/image/classification/input.py b/src/flash/image/classification/input.py index d770780f0f..7991f595a4 100644 --- a/src/flash/image/classification/input.py +++ b/src/flash/image/classification/input.py @@ -151,10 +151,7 @@ def load_data( target_formatter: Optional[TargetFormatter] = None, ) -> List[Dict[str, Any]]: files = resolve_files(data_frame, input_key, root, resolver) - if target_keys is not None: - targets = resolve_targets(data_frame, target_keys) - else: - targets = None + targets = resolve_targets(data_frame, target_keys) if target_keys is not None else None result = super().load_data(files, targets, target_formatter=target_formatter) # If we had binary multi-class targets then we also know the labels (column names) diff --git a/src/flash/image/classification/integrations/baal/loop.py b/src/flash/image/classification/integrations/baal/loop.py index 451d364309..61785b668a 100644 --- a/src/flash/image/classification/integrations/baal/loop.py +++ b/src/flash/image/classification/integrations/baal/loop.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib from copy import deepcopy from typing import Any, Dict, Optional @@ -182,10 +183,8 @@ def _reset_dataloader_for_stage(self, running_state: RunningStage): ) setattr(self.trainer, dataloader_name, None) # TODO: Resolve this within PyTorch Lightning. - try: + with contextlib.suppress(Exception): getattr(self.trainer, f"reset_{dataloader_name}")(self.trainer.lightning_module) - except Exception: - pass def _teardown(self) -> None: self.trainer.train_dataloader = None diff --git a/src/flash/image/detection/data.py b/src/flash/image/detection/data.py index 2f5bd3f0cc..9884ab9889 100644 --- a/src/flash/image/detection/data.py +++ b/src/flash/image/detection/data.py @@ -40,10 +40,7 @@ ObjectDetectionTensorInput, ) -if _FIFTYONE_AVAILABLE: - SampleCollection = "fiftyone.core.collections.SampleCollection" -else: - SampleCollection = None +SampleCollection = "fiftyone.core.collections.SampleCollection" if _FIFTYONE_AVAILABLE else None if _ICEVISION_AVAILABLE: from icevision.core import ClassMap diff --git a/src/flash/image/detection/output.py b/src/flash/image/detection/output.py index f5a8d61a4d..a5811c9e9a 100644 --- a/src/flash/image/detection/output.py +++ b/src/flash/image/detection/output.py @@ -84,10 +84,7 @@ def transform(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]] ] label = label.item() - if self._labels is not None: - label = self._labels[label] - else: - label = str(int(label)) + label = self._labels[label] if self._labels is not None else str(int(label)) detections.append( fo.Detection( diff --git a/src/flash/image/embedding/model.py b/src/flash/image/embedding/model.py index a494f0d567..2f02beafee 100644 --- a/src/flash/image/embedding/model.py +++ b/src/flash/image/embedding/model.py @@ -135,9 +135,12 @@ def __init__( ) self.input_transform = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs) - if "providers" in metadata["metadata"] and metadata["metadata"]["providers"].name == "Facebook Research/vissl": - if pretraining_transform is None: - raise ValueError("Correct pretraining_transform must be set to use VISSL") + if ( + "providers" in metadata["metadata"] + and metadata["metadata"]["providers"].name == "Facebook Research/vissl" + and pretraining_transform is None + ): + raise ValueError("Correct pretraining_transform must be set to use VISSL") def forward(self, x: Tensor) -> Any: return self.model(x) diff --git a/src/flash/image/face_detection/input_transform.py b/src/flash/image/face_detection/input_transform.py index f9d5b052b4..9a889bd414 100644 --- a/src/flash/image/face_detection/input_transform.py +++ b/src/flash/image/face_detection/input_transform.py @@ -41,7 +41,7 @@ def fastface_collate_fn(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence samples["scales"] = scales samples["paddings"] = paddings - if DataKeys.TARGET in samples.keys(): + if DataKeys.TARGET in samples: targets = samples[DataKeys.TARGET] for i, (target, scale, padding) in enumerate(zip(targets, scales, paddings)): diff --git a/src/flash/image/keypoint_detection/backbones.py b/src/flash/image/keypoint_detection/backbones.py index 0df353f5dd..2f619ed005 100644 --- a/src/flash/image/keypoint_detection/backbones.py +++ b/src/flash/image/keypoint_detection/backbones.py @@ -56,13 +56,12 @@ def from_task( ) -if _ICEVISION_AVAILABLE: - if _TORCHVISION_AVAILABLE: - model_type = icevision_models.torchvision.keypoint_rcnn - KEYPOINT_DETECTION_HEADS( - partial(load_icevision_ignore_image_size, model_type), - model_type.__name__.split(".")[-1], - backbones=get_backbones(model_type), - adapter=IceVisionKeypointDetectionAdapter, - providers=[_ICEVISION, _TORCHVISION], - ) +if _ICEVISION_AVAILABLE and _TORCHVISION_AVAILABLE: + model_type = icevision_models.torchvision.keypoint_rcnn + KEYPOINT_DETECTION_HEADS( + partial(load_icevision_ignore_image_size, model_type), + model_type.__name__.split(".")[-1], + backbones=get_backbones(model_type), + adapter=IceVisionKeypointDetectionAdapter, + providers=[_ICEVISION, _TORCHVISION], + ) diff --git a/src/flash/text/classification/adapters.py b/src/flash/text/classification/adapters.py index 517dd85da9..916db53e77 100644 --- a/src/flash/text/classification/adapters.py +++ b/src/flash/text/classification/adapters.py @@ -105,7 +105,7 @@ def tokenize(self, sample): return sample def __call__(self, samples): - return self.to_tensor(self.tokenize({key: [sample[key] for sample in samples] for key in samples[0].keys()})) + return self.to_tensor(self.tokenize({key: [sample[key] for sample in samples] for key in samples[0]})) class GenericAdapter(Adapter): diff --git a/src/flash/text/question_answering/input.py b/src/flash/text/question_answering/input.py index b587a3346b..9381ddf0b0 100644 --- a/src/flash/text/question_answering/input.py +++ b/src/flash/text/question_answering/input.py @@ -53,15 +53,14 @@ def load_data( column_names = hf_dataset.column_names if self.training or self.validating or self.testing: - if answer_column_name == "answer": - if "answer" not in column_names: - if "answer_text" in column_names and "answer_start" in column_names: - hf_dataset = hf_dataset.map(self._reshape_answer_column, batched=False) - else: - raise KeyError( - """Dataset must contain either \"answer\" key as dict type or "answer_text" and + if answer_column_name == "answer" and "answer" not in column_names: + if "answer_text" in column_names and "answer_start" in column_names: + hf_dataset = hf_dataset.map(self._reshape_answer_column, batched=False) + else: + raise KeyError( + """Dataset must contain either \"answer\" key as dict type or "answer_text" and "answer_start" as string and integer types.""" - ) + ) if not isinstance(hf_dataset[answer_column_name][0], Dict): raise TypeError( f'{answer_column_name} column should be of type dict with keys "text" and "answer_start"' diff --git a/src/flash/video/classification/data.py b/src/flash/video/classification/data.py index 531b2c6615..68351013e7 100644 --- a/src/flash/video/classification/data.py +++ b/src/flash/video/classification/data.py @@ -40,10 +40,7 @@ ) from flash.video.classification.input_transform import VideoClassificationInputTransform -if _FIFTYONE_AVAILABLE: - SampleCollection = "fiftyone.core.collections.SampleCollection" -else: - SampleCollection = None +SampleCollection = "fiftyone.core.collections.SampleCollection" if _FIFTYONE_AVAILABLE else None if _PYTORCHVIDEO_AVAILABLE: from pytorchvideo.data.clip_sampling import ClipSampler diff --git a/tests/audio/classification/test_model.py b/tests/audio/classification/test_model.py index 89bea0fb97..5b164dc912 100644 --- a/tests/audio/classification/test_model.py +++ b/tests/audio/classification/test_model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib from unittest.mock import patch import pytest @@ -23,8 +24,5 @@ @pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") def test_cli(): cli_args = ["flash", "audio_classification", "--trainer.fast_dev_run", "True"] - with patch("sys.argv", cli_args): - try: - main() - except SystemExit: - pass + with patch("sys.argv", cli_args), contextlib.suppress(SystemExit): + main() diff --git a/tests/core/serve/test_types/test_number.py b/tests/core/serve/test_types/test_number.py index 2821fdc3d7..deefd8477d 100644 --- a/tests/core/serve/test_types/test_number.py +++ b/tests/core/serve/test_types/test_number.py @@ -9,7 +9,7 @@ def test_serialize(): num = Number() tensor = torch.tensor([[1]]) - assert 1 == num.serialize(tensor) + assert num.serialize(tensor) == 1 assert isinstance(num.serialize(tensor.to(torch.float32)), float) assert isinstance(num.serialize(tensor.to(torch.float64)), float) assert isinstance(num.serialize(tensor.to(torch.int16)), int) diff --git a/tests/core/serve/test_types/test_text.py b/tests/core/serve/test_types/test_text.py index 4aa3b7afd4..3a813b05db 100644 --- a/tests/core/serve/test_types/test_text.py +++ b/tests/core/serve/test_types/test_text.py @@ -23,8 +23,8 @@ def test_custom_tokenizer(): tokenizer = CustomTokenizer("test") text = Text(tokenizer=tokenizer) - assert "encoding from test" == text.deserialize("random string") - assert "decoding from test" == text.serialize(torch.tensor([[1, 2]])) + assert text.deserialize("random string") == "encoding from test" + assert text.serialize(torch.tensor([[1, 2]])) == "decoding from test" @pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="Not testing serve.") diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py index 54b5a7fd56..8fcd81eadb 100644 --- a/tests/core/utilities/test_lightning_cli.py +++ b/tests/core/utilities/test_lightning_cli.py @@ -389,7 +389,7 @@ def test_lightning_cli_help(): assert "--data.help" in out.getvalue() skip_params = {"self"} - for param in inspect.signature(Trainer.__init__).parameters.keys(): + for param in inspect.signature(Trainer.__init__).parameters: if param not in skip_params: assert f"--trainer.{param}" in out.getvalue() diff --git a/tests/helpers/task_tester.py b/tests/helpers/task_tester.py index 497426af3f..97dbefb571 100644 --- a/tests/helpers/task_tester.py +++ b/tests/helpers/task_tester.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import functools import inspect import os @@ -130,11 +131,8 @@ def _test_jit_script(self, tmpdir): def _test_cli(self, extra_args: List): """Tests that the default Flash zero configuration runs for the task.""" cli_args = ["flash", self.cli_command, "--trainer.fast_dev_run", "True"] + extra_args - with patch("sys.argv", cli_args): - try: - main() - except SystemExit: - pass + with patch("sys.argv", cli_args), contextlib.suppress(SystemExit): + main() def _test_load_from_checkpoint_dependency_error(self): diff --git a/tests/image/classification/test_training_strategies.py b/tests/image/classification/test_training_strategies.py index 0236188768..dd524dd722 100644 --- a/tests/image/classification/test_training_strategies.py +++ b/tests/image/classification/test_training_strategies.py @@ -109,7 +109,7 @@ def test_wrongly_specified_training_strategies(): ) -@pytest.mark.skipif(not os.getenv("FLASH_RUNNING_SPECIAL_TESTS", "0") == "1", reason="Should run with special test") +@pytest.mark.skipif(os.getenv("FLASH_RUNNING_SPECIAL_TESTS", "0") != "1", reason="Should run with special test") @pytest.mark.skipif(not _LEARN2LEARN_AVAILABLE, reason="image and learn2learn libraries aren't installed.") def test_learn2learn_training_strategies_ddp(tmpdir): _test_learn2learning_training_strategies(2, "prototypicalnetworks", tmpdir, strategy="ddp") diff --git a/tests/image/face_detection/test_model.py b/tests/image/face_detection/test_model.py index d4f49a6777..a7da39c1dc 100644 --- a/tests/image/face_detection/test_model.py +++ b/tests/image/face_detection/test_model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib from unittest.mock import patch import pytest @@ -57,8 +58,5 @@ def test_fastface_backbones_registry(): @pytest.mark.skipif(not _FASTFACE_AVAILABLE, reason="fastface not installed.") def test_cli(): cli_args = ["flash", "face_detection", "--trainer.fast_dev_run", "True"] - with patch("sys.argv", cli_args): - try: - main() - except SystemExit: - pass + with patch("sys.argv", cli_args), contextlib.suppress(SystemExit): + main() diff --git a/tests/video/classification/test_model.py b/tests/video/classification/test_model.py index 9523e0bb19..d30f7d7d6e 100644 --- a/tests/video/classification/test_model.py +++ b/tests/video/classification/test_model.py @@ -99,20 +99,19 @@ def mock_video_data_frame(): with temp_encoded_video(num_frames=num_frames, fps=fps) as ( video_file_name_1, data_1, + ), temp_encoded_video(num_frames=num_frames, fps=fps) as ( + video_file_name_2, + data_2, ): - with temp_encoded_video(num_frames=num_frames, fps=fps) as ( - video_file_name_2, - data_2, - ): - data_frame = DataFrame.from_dict( - { - "file": [video_file_name_1, video_file_name_2, video_file_name_1, video_file_name_2], - "target": ["cat", "dog", "cat", "dog"], - } - ) + data_frame = DataFrame.from_dict( + { + "file": [video_file_name_1, video_file_name_2, video_file_name_1, video_file_name_2], + "target": ["cat", "dog", "cat", "dog"], + } + ) - video_duration = num_frames / fps - yield data_frame, video_duration + video_duration = num_frames / fps + yield data_frame, video_duration @contextlib.contextmanager @@ -136,10 +135,11 @@ def mock_encoded_video_dataset_folder(tmpdir): os.makedirs(str(tmp_dir / "c1")) os.makedirs(str(tmp_dir / "c2")) - with temp_encoded_video(num_frames=num_frames, fps=fps, directory=str(tmp_dir / "c1")): - with temp_encoded_video(num_frames=num_frames, fps=fps, directory=str(tmp_dir / "c2")): - video_duration = num_frames / fps - yield str(tmp_dir), video_duration + with temp_encoded_video(num_frames=num_frames, fps=fps, directory=str(tmp_dir / "c1")), temp_encoded_video( + num_frames=num_frames, fps=fps, directory=str(tmp_dir / "c2") + ): + video_duration = num_frames / fps + yield str(tmp_dir), video_duration @pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") From 2b00a8c200ee165464ec95162540098957fc61c2 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 11 May 2023 21:50:20 +0200 Subject: [PATCH 38/39] ruff: enable PT (#1557) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyproject.toml | 3 +++ .../core/integrations/labelstudio/input.py | 3 ++- tests/audio/classification/test_data.py | 2 +- tests/conftest.py | 14 +++++------- tests/core/data/test_base_viz.py | 2 +- tests/core/data/test_callback.py | 2 +- tests/core/data/test_data_module.py | 10 +++++---- tests/core/data/test_transforms.py | 4 ++-- tests/core/data/utilities/test_loading.py | 10 ++++----- .../integrations/vissl/test_strategies.py | 2 +- tests/core/optimizers/test_lr_scheduler.py | 2 +- tests/core/optimizers/test_optimizers.py | 4 ++-- tests/core/test_data.py | 12 +++++++--- tests/core/test_finetuning.py | 6 ++--- tests/core/test_model.py | 4 ++-- tests/core/test_trainer.py | 2 +- tests/core/utilities/test_embedder.py | 2 +- tests/core/utilities/test_lightning_cli.py | 22 ++++++++++--------- tests/core/utilities/test_stability.py | 2 +- tests/deprecated_api/test_remove_0_9_0.py | 2 +- tests/examples/test_scripts.py | 2 +- .../classification/test_active_learning.py | 4 ++-- tests/image/classification/test_data.py | 10 ++++----- .../detection/test_data_model_integration.py | 4 ++-- tests/image/embedding/test_model.py | 6 ++--- tests/image/instance_segm/test_model.py | 4 ++-- tests/image/keypoint_detection/test_model.py | 4 ++-- tests/image/test_backbones.py | 6 ++--- .../test_data_model_integration.py | 2 +- .../regression/test_data_model_integration.py | 6 ++--- tests/template/classification/test_model.py | 2 +- tests/video/classification/test_data.py | 4 ++-- 32 files changed, 88 insertions(+), 76 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1639f58e29..94d4a563d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,9 +73,12 @@ extend-select = [ "C4", # see: https://pypi.org/project/flake8-comprehensions "SIM", # see: https://pypi.org/project/flake8-simplify "RET", # see: https://pypi.org/project/flake8-return + "PT", # see: https://pypi.org/project/flake8-pytest-style ] ignore = [ "E731", # Do not assign a lambda expression, use a def + "PT011", # todo `pytest.raises(ValueError)` is too broad, set the `match` parameter or use a more specific exception + "PT012", # todo: `pytest.raises()` block should contain a single simple statement ] # Exclude a variety of commonly ignored directories. exclude = [ diff --git a/src/flash/core/integrations/labelstudio/input.py b/src/flash/core/integrations/labelstudio/input.py index a8956aed82..a288241e96 100644 --- a/src/flash/core/integrations/labelstudio/input.py +++ b/src/flash/core/integrations/labelstudio/input.py @@ -191,7 +191,8 @@ def _export_data_to_json(export_path: str, raw_data: List[Dict]) -> Dict: @staticmethod def _split_train_val_data(data: Dict, split: float = 0) -> List[Dict]: - assert split > 0 and split < 1 + assert split > 0 + assert split < 1 file_path = data.get("export_json", None) if not file_path: diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py index bebf8e18d4..c1adaad03b 100644 --- a/tests/audio/classification/test_data.py +++ b/tests/audio/classification/test_data.py @@ -71,7 +71,7 @@ def test_from_filepaths(tmpdir, file_generator): @pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") @pytest.mark.parametrize( - "data,from_function", + ("data", "from_function"), [ (torch.rand(3, 3, 64, 64), AudioClassificationData.from_tensors), (np.random.rand(3, 3, 64, 64), AudioClassificationData.from_numpy), diff --git a/tests/conftest.py b/tests/conftest.py index 8a6b303cc6..edf66a29e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,8 +21,8 @@ def hex(self): return str(self) -@pytest.fixture(scope="function", autouse=True) -def patch_decorators_uuid_generator_func(mocker: MockerFixture): +@pytest.fixture(autouse=True) +def patch_decorators_uuid_generator_func(mocker: MockerFixture): # noqa: PT004 call_num = 0 def _generate_sequential_uuid(): @@ -31,7 +31,6 @@ def _generate_sequential_uuid(): return UUID_String(f"callnum_{call_num}") mocker.patch("flash.core.serve.decorators.uuid4", side_effect=_generate_sequential_uuid) - yield @pytest.fixture(scope="session") @@ -55,7 +54,7 @@ def module_global_datadir(tmp_path_factory, original_global_datadir): return prep_global_datadir(tmp_path_factory, original_global_datadir) -@pytest.fixture(scope="function") +@pytest.fixture() def global_datadir(tmp_path_factory, original_global_datadir): return prep_global_datadir(tmp_path_factory, original_global_datadir) @@ -64,8 +63,7 @@ def global_datadir(tmp_path_factory, original_global_datadir): @pytest.fixture(scope="session") def squeezenet1_1_model(): - model = torchvision.models.squeezenet1_1(pretrained=True).eval() - yield model + return torchvision.models.squeezenet1_1(pretrained=True).eval() @pytest.fixture(scope="session") def lightning_squeezenet1_1_obj(): @@ -73,7 +71,7 @@ def lightning_squeezenet1_1_obj(): model = LightningSqueezenet() model.eval() - yield model + return model @pytest.fixture(scope="session") def squeezenet_servable(squeezenet1_1_model, session_global_datadir): @@ -84,7 +82,7 @@ def squeezenet_servable(squeezenet1_1_model, session_global_datadir): torch.jit.save(trace, fpth) model = Servable(fpth) - yield (model, fpth) + return (model, fpth) @pytest.fixture() def lightning_squeezenet_checkpoint_path(tmp_path): diff --git a/tests/core/data/test_base_viz.py b/tests/core/data/test_base_viz.py index 10bf8082c8..954750216f 100644 --- a/tests/core/data/test_base_viz.py +++ b/tests/core/data/test_base_viz.py @@ -168,7 +168,7 @@ def _get_result(function_name: str): dm.data_fetcher.reset() @pytest.mark.parametrize( - "func_names, valid", + ("func_names", "valid"), [ (["load_sample"], True), (["not_a_hook"], False), diff --git a/tests/core/data/test_callback.py b/tests/core/data/test_callback.py index 7794298d64..6b3d195945 100644 --- a/tests/core/data/test_callback.py +++ b/tests/core/data/test_callback.py @@ -28,7 +28,7 @@ @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @patch("pickle.dumps") # need to mock pickle or we get pickle error @patch("torch.save") # need to mock torch.save, or we get pickle error -def test_flash_callback(_, __, tmpdir): +def test_flash_callback(_, __, tmpdir): # noqa: PT019 """Test the callback hook system for fit.""" callback_mock = MagicMock() diff --git a/tests/core/data/test_data_module.py b/tests/core/data/test_data_module.py index 5614e32cf6..76f443ac91 100644 --- a/tests/core/data/test_data_module.py +++ b/tests/core/data/test_data_module.py @@ -95,12 +95,14 @@ def predict_per_batch_transform_on_device(self) -> Callable: assert len(dm.train_dataloader()) == 5 batch = next(iter(dm.train_dataloader())) assert batch.shape == torch.Size([2]) - assert batch.min() >= 0 and batch.max() < 10 + assert batch.min() >= 0 + assert batch.max() < 10 assert len(dm.val_dataloader()) == 5 batch = next(iter(dm.val_dataloader())) assert batch.shape == torch.Size([2]) - assert batch.min() >= 0 and batch.max() < 10 + assert batch.min() >= 0 + assert batch.max() < 10 class TestModel(Task): def training_step(self, batch, batch_idx): @@ -218,7 +220,7 @@ def val_load_sample(self, sample): self.val_load_sample_called = True return {"a": sample, "b": sample + 1} - def test_load_data(self, _): + def test_load_data(self, _): # noqa: PT019 return [[torch.rand(1), torch.rand(1)], [torch.rand(1), torch.rand(1)]] @@ -426,7 +428,7 @@ def validation_step(self, batch, batch_idx): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") -@pytest.mark.parametrize("sampler, callable", [(MagicMock(), True), (NonCallableMock(), False)]) +@pytest.mark.parametrize(("sampler", "callable"), [(MagicMock(), True), (NonCallableMock(), False)]) @patch("flash.core.data.data_module.DataLoader") def test_dataloaders_with_sampler(mock_dataloader, sampler, callable): train_input = TestInput(RunningStage.TRAINING, [1]) diff --git a/tests/core/data/test_transforms.py b/tests/core/data/test_transforms.py index fbc678acf7..3667ac74ee 100644 --- a/tests/core/data/test_transforms.py +++ b/tests/core/data/test_transforms.py @@ -24,7 +24,7 @@ class TestApplyToKeys: @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( - "sample, keys, expected", + ("sample", "keys", "expected"), [ ({DataKeys.INPUT: "test"}, DataKeys.INPUT, "test"), ( @@ -49,7 +49,7 @@ def test_forward(self, sample, keys, expected): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( - "transform, expected", + ("transform", "expected"), [ ( ApplyToKeys(DataKeys.INPUT, torch.nn.ReLU()), diff --git a/tests/core/data/utilities/test_loading.py b/tests/core/data/utilities/test_loading.py index f8a4c8f48f..3ea85e2d98 100644 --- a/tests/core/data/utilities/test_loading.py +++ b/tests/core/data/utilities/test_loading.py @@ -82,7 +82,7 @@ def write_tsv(file_path): @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.parametrize( - "extension,write", + ("extension", "write"), [(extension, write_image) for extension in IMG_EXTENSIONS] + [(extension, write_numpy) for extension in NP_EXTENSIONS] # it shouldn't try to expand glob patterns in filenames @@ -100,7 +100,7 @@ def test_load_image(tmpdir, extension, write): @pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") @pytest.mark.parametrize( - "extension,write", + ("extension", "write"), [(extension, write_image) for extension in IMG_EXTENSIONS] + [(extension, write_numpy) for extension in NP_EXTENSIONS] + [(extension, write_audio) for extension in AUDIO_EXTENSIONS], @@ -116,7 +116,7 @@ def test_load_spectrogram(tmpdir, extension, write): @pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.") -@pytest.mark.parametrize("extension,write", [(extension, write_audio) for extension in AUDIO_EXTENSIONS]) +@pytest.mark.parametrize(("extension", "write"), [(extension, write_audio) for extension in AUDIO_EXTENSIONS]) def test_load_audio(tmpdir, extension, write): file_path = os.path.join(tmpdir, f"test{extension}") write(file_path) @@ -129,7 +129,7 @@ def test_load_audio(tmpdir, extension, write): @pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") @pytest.mark.parametrize( - "extension,write", + ("extension", "write"), [(extension, write_csv) for extension in CSV_EXTENSIONS] + [(extension, write_tsv) for extension in TSV_EXTENSIONS], ) def test_load_data_frame(tmpdir, extension, write): @@ -142,7 +142,7 @@ def test_load_data_frame(tmpdir, extension, write): @pytest.mark.parametrize( - "path, loader, target_type", + ("path", "loader", "target_type"), [ pytest.param( "https://pl-flash-data.s3.amazonaws.com/images/ant_1.jpg", diff --git a/tests/core/integrations/vissl/test_strategies.py b/tests/core/integrations/vissl/test_strategies.py index c0118586cb..2a6faf9cfe 100644 --- a/tests/core/integrations/vissl/test_strategies.py +++ b/tests/core/integrations/vissl/test_strategies.py @@ -39,7 +39,7 @@ @pytest.mark.skipif(not (_TORCHVISION_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") @pytest.mark.parametrize( - "training_strategy, head_name, loss_fn_class, head_class, hooks_list", + ("training_strategy", "head_name", "loss_fn_class", "head_class", "hooks_list"), [ ("barlow_twins", "barlow_twins_head", BarlowTwinsLoss, SimCLRHead, [TrainingSetupHook]), ( diff --git a/tests/core/optimizers/test_lr_scheduler.py b/tests/core/optimizers/test_lr_scheduler.py index b18f9b4a1f..703fc851aa 100644 --- a/tests/core/optimizers/test_lr_scheduler.py +++ b/tests/core/optimizers/test_lr_scheduler.py @@ -23,7 +23,7 @@ @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( - "lr, warmup_epochs, max_epochs, warmup_start_lr, eta_min", + ("lr", "warmup_epochs", "max_epochs", "warmup_start_lr", "eta_min"), [ (1, 10, 3200, 0.001, 0.0), (1e-4, 40, 300, 1e-6, 1e-5), diff --git a/tests/core/optimizers/test_optimizers.py b/tests/core/optimizers/test_optimizers.py index bdfff256ee..51b82233b2 100644 --- a/tests/core/optimizers/test_optimizers.py +++ b/tests/core/optimizers/test_optimizers.py @@ -21,7 +21,7 @@ @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( - "optim_fn, lr, kwargs", + ("optim_fn", "lr", "kwargs"), [ (LARS, 0.1, {}), (LARS, 0.1, {"weight_decay": 0.001}), @@ -44,7 +44,7 @@ def test_optim_call(tmpdir, optim_fn, lr, kwargs): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") -@pytest.mark.parametrize("optim_fn, lr", [(LARS, 0.1), (LAMB, 1e-3)]) +@pytest.mark.parametrize(("optim_fn", "lr"), [(LARS, 0.1), (LAMB, 1e-3)]) def test_optim_with_scheduler(tmpdir, optim_fn, lr): max_epochs = 10 layer = nn.Linear(10, 1) diff --git a/tests/core/test_data.py b/tests/core/test_data.py index 522c7224e8..de7513713d 100644 --- a/tests/core/test_data.py +++ b/tests/core/test_data.py @@ -39,13 +39,19 @@ def test_init(): test_input = DatasetInput(RunningStage.TESTING, DummyDataset()) data_module = DataModule(train_input, batch_size=1) - assert data_module.train_dataset and not data_module.val_dataset and not data_module.test_dataset + assert data_module.train_dataset + assert not data_module.val_dataset + assert not data_module.test_dataset data_module = DataModule(train_input, val_input, batch_size=1) - assert data_module.train_dataset and data_module.val_dataset and not data_module.test_dataset + assert data_module.train_dataset + assert data_module.val_dataset + assert not data_module.test_dataset data_module = DataModule(train_input, val_input, test_input, batch_size=1) - assert data_module.train_dataset and data_module.val_dataset and data_module.test_dataset + assert data_module.train_dataset + assert data_module.val_dataset + assert data_module.test_dataset @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index 72115158fb..e49a1817f5 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -137,7 +137,7 @@ def on_train_epoch_start(self, trainer, pl_module): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( - "strategy, plugins", + ("strategy", "plugins"), [ ("no_freeze", None), ("freeze", None), @@ -222,7 +222,7 @@ def test_finetuning(tmpdir, strategy, lr_scheduler, checker_class, checker_class @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( - "strategy,error", + ("strategy", "error"), [ (None, TypeError), ("chocolate", ValueError), @@ -246,7 +246,7 @@ def test_finetuning_errors_and_exceptions(strategy, error): @pytest.mark.parametrize( - "strategy_key, strategy_metadata", + ("strategy_key", "strategy_metadata"), [ ("no_freeze", None), ("freeze", None), diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 110d4f87c3..909f126d60 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -208,7 +208,7 @@ def test_classification_task_trainer_predict(tmpdir): @pytest.mark.parametrize( - ["cls", "filename"], + ("cls", "filename"), [ pytest.param( ImageClassifier, @@ -343,7 +343,7 @@ def custom_steplr_configuration_return_as_dict(optimizer): "optim", ["Adadelta", functools.partial(torch.optim.Adadelta, eps=0.5), ("Adadelta", {"eps": 0.5})] ) @pytest.mark.parametrize( - "sched, interval", + ("sched", "interval"), [ (None, "epoch"), ("custom_steplr_configuration_return_as_instance", "epoch"), diff --git a/tests/core/test_trainer.py b/tests/core/test_trainer.py index 928f990737..695185e372 100644 --- a/tests/core/test_trainer.py +++ b/tests/core/test_trainer.py @@ -68,7 +68,7 @@ def finetune_function( @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") -@pytest.mark.parametrize("callbacks, should_warn", [([], False), ([NoFreeze()], True)]) +@pytest.mark.parametrize(("callbacks", "should_warn"), [([], False), ([NoFreeze()], True)]) def test_trainer_fit(tmpdir, callbacks, should_warn): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) train_dl = DataLoader(DummyDataset()) diff --git a/tests/core/utilities/test_embedder.py b/tests/core/utilities/test_embedder.py index 89abd1c5f9..5f4c297d94 100644 --- a/tests/core/utilities/test_embedder.py +++ b/tests/core/utilities/test_embedder.py @@ -38,7 +38,7 @@ def __init__(self, n_layers): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") -@pytest.mark.parametrize("layer, size", [("backbone.1", 30), ("output", 40), ("", 40)]) +@pytest.mark.parametrize(("layer", "size"), [("backbone.1", 30), ("output", 40), ("", 40)]) def test_embedder(layer, size): """Tests that the embedder ``predict_step`` correctly returns the output from the requested layer.""" model = EmbedderTestModel( diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py index 8fcd81eadb..f8832547f0 100644 --- a/tests/core/utilities/test_lightning_cli.py +++ b/tests/core/utilities/test_lightning_cli.py @@ -74,7 +74,7 @@ def test_add_argparse_args_redefined(cli_args): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( - ["cli_args", "expected"], + ("cli_args", "expected"), [ ("--auto_lr_find=True --auto_scale_batch_size=power", {"auto_lr_find": True, "auto_scale_batch_size": "power"}), ( @@ -103,7 +103,7 @@ def test_parse_args_parsing(cli_args, expected): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( - ["cli_args", "expected", "instantiate"], + ("cli_args", "expected", "instantiate"), [ (["--gpus", "[0, 2]"], {"gpus": [0, 2]}, False), (["--tpu_cores=[1,3]"], {"tpu_cores": [1, 3]}, False), @@ -125,7 +125,7 @@ def test_parse_args_parsing_complex_types(cli_args, expected, instantiate): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( - ["cli_args", "expected_gpu"], + ("cli_args", "expected_gpu"), [ ("--gpus 1", [0]), ("--gpus 0,", [0]), @@ -148,7 +148,7 @@ def test_parse_args_parsing_gpus(mocker, cli_args, expected_gpu): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( - ["cli_args", "extra_args"], + ("cli_args", "extra_args"), [ ({}, {}), ({"logger": False}, {}), @@ -188,7 +188,7 @@ def trainer_builder( @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") -@pytest.mark.parametrize(["trainer_class", "model_class"], [(Trainer, Model), (trainer_builder, model_builder)]) +@pytest.mark.parametrize(("trainer_class", "model_class"), [(Trainer, Model), (trainer_builder, model_builder)]) def test_lightning_cli(trainer_class, model_class, monkeypatch): """Test that LightningCLI correctly instantiates model, trainer and calls fit.""" @@ -217,7 +217,8 @@ def on_train_start(callback, trainer, _): with patch("sys.argv", ["any.py", "--model.model_param=7", "--trainer.limit_train_batches=100"]): cli = LightningCLI(model_class, trainer_class=trainer_class, save_config_callback=SaveConfigCallback) - assert hasattr(cli.trainer, "ran_asserts") and cli.trainer.ran_asserts + assert hasattr(cli.trainer, "ran_asserts") + assert cli.trainer.ran_asserts class TestModelCallbacks(BoringModel): @@ -307,7 +308,8 @@ def test_lightning_cli_args(tmpdir): assert os.path.isfile(config_path) with open(config_path) as f: config = yaml.safe_load(f.read()) - assert "model" not in config and "model" not in cli.config # no arguments to include + assert "model" not in config + assert "model" not in cli.config assert config["data"] == cli.config["data"] assert config["trainer"] == cli.config["trainer"] @@ -577,13 +579,13 @@ def on_exception(self, execption): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") -@pytest.mark.parametrize("logger", (False, True)) +@pytest.mark.parametrize("logger", [False, True]) @pytest.mark.parametrize( "trainer_kwargs", - ( + [ {"accelerator": "cpu", "strategy": "ddp"}, {"accelerator": "cpu", "strategy": "ddp", "plugins": "ddp_find_unused_parameters_false"}, - ), + ], ) @pytest.mark.xfail(reason="Bugs in PL >= 1.6.0") def test_cli_ddp_spawn_save_config_callback(tmpdir, logger, trainer_kwargs): diff --git a/tests/core/utilities/test_stability.py b/tests/core/utilities/test_stability.py index cec2fb247c..16916be03f 100644 --- a/tests/core/utilities/test_stability.py +++ b/tests/core/utilities/test_stability.py @@ -39,7 +39,7 @@ def _beta_func_custom_message(): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") @pytest.mark.parametrize( - "callable, match", + ("callable", "match"), [ (_BetaType, "This feature is currently in Beta."), (_BetaTypeCustomMessage, "_BetaTypeCustomMessage is currently in Beta."), diff --git a/tests/deprecated_api/test_remove_0_9_0.py b/tests/deprecated_api/test_remove_0_9_0.py index be8a2090ac..363a0d788b 100644 --- a/tests/deprecated_api/test_remove_0_9_0.py +++ b/tests/deprecated_api/test_remove_0_9_0.py @@ -19,7 +19,7 @@ @pytest.mark.skipif(not _VISSL_AVAILABLE, reason="vissl not installed.") @pytest.mark.parametrize( - "deprecated_backbone, alternative_backbone", + ("deprecated_backbone", "alternative_backbone"), [("resnet", "resnet50"), ("vision_transformer", "vit_small_patch16_224")], ) def test_0_9_0_embedder_models(deprecated_backbone, alternative_backbone): diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 80fea16b43..745912852a 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -45,7 +45,7 @@ @patch.dict(os.environ, {"FLASH_TESTING": "1"}) @pytest.mark.parametrize( - "folder,fname", + ("folder", "fname"), [ pytest.param( "", diff --git a/tests/image/classification/test_active_learning.py b/tests/image/classification/test_active_learning.py index 4e1164788e..6ad08932b1 100644 --- a/tests/image/classification/test_active_learning.py +++ b/tests/image/classification/test_active_learning.py @@ -30,7 +30,7 @@ # ======== Mock functions ======== -@pytest.fixture +@pytest.fixture() def simple_datamodule(tmpdir): train_dir = Path(tmpdir / "train") train_dir.mkdir() @@ -63,7 +63,7 @@ def simple_datamodule(tmpdir): @pytest.mark.skipif( not (_TOPIC_IMAGE_AVAILABLE and _BAAL_AVAILABLE), reason="image and baal libraries aren't installed." ) -@pytest.mark.parametrize("initial_num_labels, query_size", [(0, 5), (5, 5)]) +@pytest.mark.parametrize(("initial_num_labels", "query_size"), [(0, 5), (5, 5)]) def test_active_learning_training(simple_datamodule, initial_num_labels, query_size): seed_everything(42) diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index 87e5ecae7b..1aace9a033 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -419,7 +419,7 @@ def test_from_filepaths_multilabel(tmpdir): @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.parametrize( - "data,from_function", + ("data", "from_function"), [ (torch.rand(3, 3, 196, 196), ImageClassificationData.from_tensors), (np.random.rand(3, 3, 196, 196), ImageClassificationData.from_numpy), @@ -546,7 +546,7 @@ def test_from_datasets(): assert labels.shape == (2,) -@pytest.fixture +@pytest.fixture() def image_tmpdir(tmpdir): (tmpdir / "train").mkdir() Image.new("RGB", (128, 128)).save(str(tmpdir / "train" / "image_1.png")) @@ -554,7 +554,7 @@ def image_tmpdir(tmpdir): return tmpdir / "train" -@pytest.fixture +@pytest.fixture() def single_target_csv(image_tmpdir): with open(image_tmpdir / "metadata.csv", "w") as csvfile: fieldnames = ["image", "target"] @@ -582,7 +582,7 @@ def test_from_csv_single_target(single_target_csv): assert labels.shape == (2,) -@pytest.fixture +@pytest.fixture() def multi_target_csv(image_tmpdir): with open(image_tmpdir / "metadata.csv", "w") as csvfile: fieldnames = ["image", "target_1", "target_2"] @@ -610,7 +610,7 @@ def test_from_csv_multi_target(multi_target_csv): assert labels.shape == (2, 2) -@pytest.fixture +@pytest.fixture() def bad_csv_no_image(image_tmpdir): with open(image_tmpdir / "metadata.csv", "w") as csvfile: fieldnames = ["image", "target"] diff --git a/tests/image/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py index fd45a57073..4ebf18752f 100644 --- a/tests/image/detection/test_data_model_integration.py +++ b/tests/image/detection/test_data_model_integration.py @@ -34,7 +34,7 @@ @pytest.mark.skipif(not _COCO_AVAILABLE, reason="coco is not installed for testing") -@pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn"), ("faster_rcnn", "resnet18_fpn")]) +@pytest.mark.parametrize(("head", "backbone"), [("retinanet", "resnet18_fpn"), ("faster_rcnn", "resnet18_fpn")]) def test_detection(tmpdir, head, backbone): train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir) @@ -56,7 +56,7 @@ def test_detection(tmpdir, head, backbone): @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing") -@pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn")]) +@pytest.mark.parametrize(("head", "backbone"), [("retinanet", "resnet18_fpn")]) def test_detection_fiftyone(tmpdir, head, backbone): train_dataset = _create_synth_fiftyone_dataset(tmpdir) diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py index 671a1ea330..7848fdd84f 100644 --- a/tests/image/embedding/test_model.py +++ b/tests/image/embedding/test_model.py @@ -51,7 +51,7 @@ def check_forward_output(self, output: Any): @pytest.mark.skipif(torch.cuda.device_count() > 1, reason="VISSL integration doesn't support multi-GPU") @pytest.mark.skipif(not (_TOPIC_IMAGE_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") @pytest.mark.parametrize( - "backbone, training_strategy, head, pretraining_transform, embedding_size", + ("backbone", "training_strategy", "head", "pretraining_transform", "embedding_size"), [ ("resnet18", "simclr", "simclr_head", "simclr_transform", 512), ("resnet18", "barlow_twins", "barlow_twins_head", "barlow_twins_transform", 512), @@ -89,7 +89,7 @@ def test_vissl_training(backbone, training_strategy, head, pretraining_transform @pytest.mark.skipif(not (_TOPIC_IMAGE_AVAILABLE and _VISSL_AVAILABLE), reason="vissl not installed.") @pytest.mark.parametrize( - "backbone, training_strategy, head, pretraining_transform, expected_exception", + ("backbone", "training_strategy", "head", "pretraining_transform", "expected_exception"), [ ("resnet18", "simclr", "simclr_head", None, ValueError), ("resnet18", "simclr", None, "simclr_transform", KeyError), @@ -109,7 +109,7 @@ def test_vissl_training_with_wrong_arguments( @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="torch vision not installed.") @pytest.mark.parametrize( - "backbone, embedding_size", + ("backbone", "embedding_size"), [ ("resnet18", 512), ("vit_small_patch16_224", 384), diff --git a/tests/image/instance_segm/test_model.py b/tests/image/instance_segm/test_model.py index e4637cb578..ae9fb1059b 100644 --- a/tests/image/instance_segm/test_model.py +++ b/tests/image/instance_segm/test_model.py @@ -32,7 +32,7 @@ COCODataConfig = collections.namedtuple("COCODataConfig", "train_folder train_ann_file predict_folder") -@pytest.fixture +@pytest.fixture() def coco_instances(tmpdir): rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) os.makedirs(tmpdir / "train_folder", exist_ok=True) @@ -134,7 +134,7 @@ def example_test_sample(self): @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.parametrize("backbone, head", [("resnet18_fpn", "mask_rcnn")]) +@pytest.mark.parametrize(("backbone", "head"), [("resnet18_fpn", "mask_rcnn")]) def test_model(coco_instances, backbone, head): datamodule = InstanceSegmentationData.from_coco( train_folder=coco_instances.train_folder, diff --git a/tests/image/keypoint_detection/test_model.py b/tests/image/keypoint_detection/test_model.py index 8c18a45826..02769e503b 100644 --- a/tests/image/keypoint_detection/test_model.py +++ b/tests/image/keypoint_detection/test_model.py @@ -32,7 +32,7 @@ COCODataConfig = collections.namedtuple("COCODataConfig", "train_folder train_ann_file predict_folder") -@pytest.fixture +@pytest.fixture() def coco_keypoints(tmpdir): rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) os.makedirs(tmpdir / "train_folder", exist_ok=True) @@ -141,7 +141,7 @@ def example_test_sample(self): @pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.") -@pytest.mark.parametrize("backbone, head", [("resnet18_fpn", "keypoint_rcnn")]) +@pytest.mark.parametrize(("backbone", "head"), [("resnet18_fpn", "keypoint_rcnn")]) def test_model(coco_keypoints, backbone, head): datamodule = KeypointDetectionData.from_coco( train_folder=coco_keypoints.train_folder, diff --git a/tests/image/test_backbones.py b/tests/image/test_backbones.py index 842027b571..5d82b004d0 100644 --- a/tests/image/test_backbones.py +++ b/tests/image/test_backbones.py @@ -21,7 +21,7 @@ @pytest.mark.parametrize( - ["backbone", "expected_num_features"], + ("backbone", "expected_num_features"), [ pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="No torchvision")), pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="No timm")), @@ -38,7 +38,7 @@ def test_image_classifier_backbones_registry(backbone, expected_num_features): @pytest.mark.parametrize( - ["backbone", "pretrained", "expected_num_features"], + ("backbone", "pretrained", "expected_num_features"), [ pytest.param( "resnet50", @@ -59,7 +59,7 @@ def test_pretrained_weights_registry(backbone, pretrained, expected_num_features @pytest.mark.parametrize( - ["backbone", "pretrained"], + ("backbone", "pretrained"), [ pytest.param("resnet50w2", True), pytest.param("resnet50w4", "supervised"), diff --git a/tests/tabular/classification/test_data_model_integration.py b/tests/tabular/classification/test_data_model_integration.py index 010d58b275..f52f045830 100644 --- a/tests/tabular/classification/test_data_model_integration.py +++ b/tests/tabular/classification/test_data_model_integration.py @@ -32,7 +32,7 @@ @pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") @pytest.mark.parametrize( - "backbone,fields", + ("backbone", "fields"), [ ("tabnet", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("tabtransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), diff --git a/tests/tabular/regression/test_data_model_integration.py b/tests/tabular/regression/test_data_model_integration.py index a2144c8ee9..0a01bac532 100644 --- a/tests/tabular/regression/test_data_model_integration.py +++ b/tests/tabular/regression/test_data_model_integration.py @@ -41,7 +41,7 @@ @pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") @pytest.mark.parametrize( - "backbone,fields", + ("backbone", "fields"), [ ("tabnet", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("tabtransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), @@ -75,7 +75,7 @@ def test_regression_data_frame(backbone, fields, tmpdir): @pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") @pytest.mark.parametrize( - "backbone,fields", + ("backbone", "fields"), [ ("tabnet", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("tabtransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), @@ -106,7 +106,7 @@ def test_regression_dicts(backbone, fields, tmpdir): @pytest.mark.skipif(not _TOPIC_TABULAR_AVAILABLE, reason="tabular libraries aren't installed.") @pytest.mark.parametrize( - "backbone,fields", + ("backbone", "fields"), [ ("tabnet", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("tabtransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), diff --git a/tests/template/classification/test_model.py b/tests/template/classification/test_model.py index ca9b7d455a..5b5f54d649 100644 --- a/tests/template/classification/test_model.py +++ b/tests/template/classification/test_model.py @@ -123,7 +123,7 @@ def test_predict_sklearn(): @pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.") -@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 16),))]) +@pytest.mark.parametrize(("jitter", "args"), [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 16),))]) @pytest.mark.xfail(RuntimeError, reason="TemplateSKLearnClassifier is not attached to a `Trainer`") # fixme def test_jit(tmpdir, jitter, args): path = os.path.join(tmpdir, "testing_model.pt") diff --git a/tests/video/classification/test_data.py b/tests/video/classification/test_data.py index c6cd9d2b0d..d174e2bd5f 100644 --- a/tests/video/classification/test_data.py +++ b/tests/video/classification/test_data.py @@ -63,7 +63,7 @@ def _check_frames(data, expected_frames_count: Union[list, int]): @pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") @pytest.mark.parametrize( - "input_data, input_targets, expected_frames_count", + ("input_data", "input_targets", "expected_frames_count"), [ ([temp_encoded_tensors(5), temp_encoded_tensors(5)], ["label1", "label2"], [5, 5]), ([temp_encoded_tensors(5), temp_encoded_tensors(10)], ["label1", "label2"], [5, 10]), @@ -81,7 +81,7 @@ def test_load_data_from_tensors(input_data, input_targets, expected_frames_count @pytest.mark.skipif(not _TOPIC_VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.") @pytest.mark.parametrize( - "input_data, input_targets, error_type, match", + ("input_data", "input_targets", "error_type", "match"), [ (torch.tensor(1), ["label1"], ValueError, "dimension should be"), (torch.randint(size=(2, 3), low=0, high=255), ["label"], ValueError, "dimension should be"), From 59dab9bfbc81e2dc0869f34a10835cbe48783d3b Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 12 May 2023 11:00:34 +0200 Subject: [PATCH 39/39] ci: test min extras (#1558) * ci: test min extras * tests --- .github/dependabot.yml | 2 +- .github/workflows/ci-testing.yml | 5 +++-- requirements/testing_vision.txt | 0 tests/vision/.gitkeep | 0 4 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 requirements/testing_vision.txt create mode 100644 tests/vision/.gitkeep diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 2b16b225d3..7ea5aa01f5 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -8,7 +8,7 @@ updates: directory: "/" # Check for updates once a week schedule: - interval: "monthly" + interval: "weekly" # Labels on pull requests for version updates only labels: ["enhancement"] pull-request-branch-name: diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index ed39516f73..6c9496c147 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -42,6 +42,8 @@ jobs: - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'graph', extra: [] } - { os: 'ubuntu-20.04', python-version: 3.9, topic: 'audio', extra: [] } - { os: 'ubuntu-20.04', python-version: 3.8, topic: 'core', extra: [], requires: 'oldest' } + - { os: 'ubuntu-20.04', python-version: 3.8, topic: 'serve', extra: [], requires: 'oldest' } + - { os: 'ubuntu-20.04', python-version: 3.8, topic: 'vision', extra: [], requires: 'oldest' } # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 50 @@ -74,8 +76,7 @@ jobs: if: matrix.requires == 'oldest' run: | import glob, os - # FixMe: shall be minimal for ALL dependencies not only base - # files = glob.glob(os.path.join("requirements", "*.txt")) + ['requirements.txt'] + files = glob.glob(os.path.join("requirements", "*.txt")) + ['requirements.txt'] files = ['requirements.txt'] for fname in files: lines = [line.replace('>=', '==') for line in open(fname).readlines()] diff --git a/requirements/testing_vision.txt b/requirements/testing_vision.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/vision/.gitkeep b/tests/vision/.gitkeep new file mode 100644 index 0000000000..e69de29bb2