diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 2c7b5ff9ba..8e08e095ba 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -22,6 +22,7 @@ from torch.optim import Optimizer from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _PL_AVAILABLE class FinetuningStrategies(LightningEnum): @@ -170,11 +171,13 @@ def finetune_function( ] _FINETUNING_STRATEGIES_REGISTRY = FlashRegistry("finetuning_strategies") -for strategy in FinetuningStrategies: - _FINETUNING_STRATEGIES_REGISTRY( - name=strategy.value, - fn=partial(FlashBaseFinetuning, strategy_key=strategy), - ) + +if _PL_AVAILABLE: + for strategy in FinetuningStrategies: + _FINETUNING_STRATEGIES_REGISTRY( + name=strategy.value, + fn=partial(FlashBaseFinetuning, strategy_key=strategy), + ) class NoFreeze(FlashBaseFinetuning): diff --git a/flash/core/utilities/flash_cli.py b/flash/core/utilities/flash_cli.py index 63d5e5defd..bd0992afba 100644 --- a/flash/core/utilities/flash_cli.py +++ b/flash/core/utilities/flash_cli.py @@ -26,7 +26,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden import flash -from flash import DataModule +from flash.core.data.data_module import DataModule from flash.core.data.io.input import InputFormat from flash.core.utilities.lightning_cli import ( class_from_function, diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 6e643b16c9..581c5cd719 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -66,6 +66,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_AVAILABLE = _module_available("torch") +_PL_AVAILABLE = _module_available("pytorch_lightning") _BOLTS_AVAILABLE = _module_available("pl_bolts") and _compare_version("torch", operator.lt, "1.9.0") _PANDAS_AVAILABLE = _module_available("pandas") _SKLEARN_AVAILABLE = _module_available("sklearn")