Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix RTD Build (#887)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Oct 25, 2021
1 parent 1dcdbbe commit 2d1b242
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 36 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def _package_list_from_file(pfile):
"pytorch-tabnet": "pytorch_tabnet",
"pyDeprecate": "deprecate",
}
MOCK_PACKAGES = ["PyYAML", "tqdm"]
MOCK_PACKAGES = ["numpy", "PyYAML", "tqdm"]
if SPHINX_MOCK_REQUIREMENTS:
# mock also base packages when we are on RTD since we don't install them there
MOCK_PACKAGES += _package_list_from_file(os.path.join(_PATH_ROOT, "requirements.txt"))
Expand Down
Empty file.
13 changes: 6 additions & 7 deletions flash/core/integrations/labelstudio/data_source.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import json
import os
from pathlib import Path
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
from typing import Any, Mapping, Optional, Sequence, Union

import torch
from pytorch_lightning.utilities.cloud_io import get_filesystem

from flash import DataSource
from flash.core.data.auto_dataset import AutoDataset, IterableAutoDataset
from flash.core.data.data_source import DefaultDataKeys, has_len
from flash.core.data.data_source import DataSource, DefaultDataKeys, has_len
from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE, _TEXT_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.stages import RunningStage

if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader
DATA_TYPE = TypeVar("DATA_TYPE")

if _TEXT_AVAILABLE:
from transformers import AutoTokenizer


class LabelStudioDataSource(DataSource):
Expand Down Expand Up @@ -80,7 +81,7 @@ def load_sample(self, sample: Mapping[str, Any] = None, dataset: Optional[Any] =

def generate_dataset(
self,
data: Optional[DATA_TYPE],
data: Optional[Any],
running_stage: RunningStage,
) -> Optional[Union[AutoDataset, IterableAutoDataset]]:
"""Generate dataset from loaded data."""
Expand Down Expand Up @@ -201,8 +202,6 @@ class LabelStudioTextClassificationDataSource(LabelStudioDataSource):
def __init__(self, backbone=None, max_length=128):
super().__init__()
if backbone:
if _TEXT_AVAILABLE:
from transformers import AutoTokenizer
self.backbone = backbone
self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
self.max_length = max_length
Expand Down
9 changes: 5 additions & 4 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@
SerializerMapping,
)
from flash.core.data.properties import ProcessState
from flash.core.optimizers import _OPTIMIZERS_REGISTRY, _SCHEDULERS_REGISTRY
from flash.core.optimizers.optimizers import _OPTIMIZERS_REGISTRY
from flash.core.optimizers.schedulers import _SCHEDULERS_REGISTRY
from flash.core.registry import FlashRegistry
from flash.core.serve import Composition
from flash.core.utilities import providers
from flash.core.serve.composition import Composition
from flash.core.utilities.apply_func import get_callable_dict
from flash.core.utilities.imports import requires
from flash.core.utilities.providers import _HUGGINGFACE
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import (
DESERIALIZER_TYPE,
Expand Down Expand Up @@ -979,7 +980,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]:

# Providers part
if lr_scheduler_metadata is not None and "providers" in lr_scheduler_metadata.keys():
if lr_scheduler_metadata["providers"] == providers._HUGGINGFACE:
if lr_scheduler_metadata["providers"] == _HUGGINGFACE:
if lr_scheduler_data["name"] != "constant_schedule":
num_training_steps: int = self.get_num_training_steps()
num_warmup_steps: int = self._compute_warmup(
Expand Down
17 changes: 9 additions & 8 deletions flash/core/optimizers/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@
from torch import optim

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TORCH_OPTIMIZER_AVAILABLE
from flash.core.utilities.imports import _TORCH_AVAILABLE, _TORCH_OPTIMIZER_AVAILABLE

_OPTIMIZERS_REGISTRY = FlashRegistry("optimizer")

_optimizers: List[Callable] = []
for n in dir(optim):
_optimizer = getattr(optim, n)
if _TORCH_AVAILABLE:
_optimizers: List[Callable] = []
for n in dir(optim):
_optimizer = getattr(optim, n)

if isclass(_optimizer) and _optimizer != optim.Optimizer and issubclass(_optimizer, optim.Optimizer):
_optimizers.append(_optimizer)
if isclass(_optimizer) and _optimizer != optim.Optimizer and issubclass(_optimizer, optim.Optimizer):
_optimizers.append(_optimizer)

for fn in _optimizers:
_OPTIMIZERS_REGISTRY(fn, name=fn.__name__.lower())
for fn in _optimizers:
_OPTIMIZERS_REGISTRY(fn, name=fn.__name__.lower())


if _TORCH_OPTIMIZER_AVAILABLE:
Expand Down
24 changes: 12 additions & 12 deletions flash/core/optimizers/schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,26 @@
)

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE
from flash.core.utilities.imports import _TORCH_AVAILABLE, _TRANSFORMERS_AVAILABLE
from flash.core.utilities.providers import _HUGGINGFACE

_SCHEDULERS_REGISTRY = FlashRegistry("scheduler")
_STEP_SCHEDULERS = (StepLR, MultiStepLR, CosineAnnealingLR, CyclicLR, CosineAnnealingWarmRestarts)

schedulers: List[_LRScheduler] = []
for n in dir(lr_scheduler):
sched = getattr(lr_scheduler, n)
if _TORCH_AVAILABLE:
schedulers: List[_LRScheduler] = []
for n in dir(lr_scheduler):
sched = getattr(lr_scheduler, n)

if inspect.isclass(sched) and sched != _LRScheduler and issubclass(sched, _LRScheduler):
schedulers.append(sched)
if inspect.isclass(sched) and sched != _LRScheduler and issubclass(sched, _LRScheduler):
schedulers.append(sched)

# Adding `ReduceLROnPlateau` separately as it is subclassed from `object` and not `_LRScheduler`.
schedulers.append(ReduceLROnPlateau)
# Adding `ReduceLROnPlateau` separately as it is subclassed from `object` and not `_LRScheduler`.
schedulers.append(ReduceLROnPlateau)


for scheduler in schedulers:
interval = "step" if issubclass(scheduler, _STEP_SCHEDULERS) else "epoch"
_SCHEDULERS_REGISTRY(scheduler, name=scheduler.__name__.lower(), interval=interval)
for scheduler in schedulers:
interval = "step" if issubclass(scheduler, _STEP_SCHEDULERS) else "epoch"
_SCHEDULERS_REGISTRY(scheduler, name=scheduler.__name__.lower(), interval=interval)

if _TRANSFORMERS_AVAILABLE:
from transformers import optimization
Expand Down
2 changes: 1 addition & 1 deletion flash/image/classification/integrations/baal/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader, Dataset, random_split

from flash import DataModule
from flash.core.data.auto_dataset import BaseAutoDataset
from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DataPipeline
from flash.core.utilities.imports import _BAAL_AVAILABLE, requires

Expand Down
3 changes: 0 additions & 3 deletions requirements/docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,3 @@ sphinx-paramlinks>=0.5.1
sphinx-togglebutton>=0.2
sphinx-copybutton>=0.3
jinja2
numpy>=1.21.2 # hotfix for docs, failing without numpy install
pandas
torch

0 comments on commit 2d1b242

Please sign in to comment.