Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
ananyahjha93 committed Nov 9, 2021
1 parent bd9d886 commit 12c4f0f
Show file tree
Hide file tree
Showing 16 changed files with 29 additions and 49 deletions.
5 changes: 1 addition & 4 deletions docs/extensions/autodatasources.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ def _resolve_transforms(_):
return None

input_transform = PatchedInputTransform()
inputs = {
    input: input_transform.input_of_name(input)
    for input in input_transform.available_inputs()
}
inputs = {input: input_transform.input_of_name(input) for input in input_transform.available_inputs()}

ENVIRONMENT.get_template("base.rst")

Expand Down
6 changes: 3 additions & 3 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -307,15 +307,15 @@ Here is the :class:`~flash.core.data.auto_dataset.AutoDataset` pseudo-code.
def __init__(
self,
data: List[Any], # output of `Input.load_data`
data_source: Input,
input: Input,
running_stage: RunningStage,
):
self.data = data
self.data_source = data_source
self.input = input
def __getitem__(self, index: int):
return self.data_source.load_sample(self.data[index])
return self.input.load_sample(self.data[index])
def __len__(self):
return len(self.data)
Expand Down
1 change: 1 addition & 0 deletions flash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from flash.core.data.callback import FlashCallback
from flash.core.data.data_module import DataModule # noqa: E402
from flash.core.data.datasets import FlashDataset, FlashIterableDataset
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.io.output import Output
from flash.core.data.io.output_transform import OutputTransform
Expand Down
8 changes: 1 addition & 7 deletions flash/audio/speech_recognition/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,7 @@

import flash
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import (
DatasetInput,
Input,
InputDataKeys,
InputFormat,
PathsInput,
)
from flash.core.data.io.input import DatasetInput, Input, InputDataKeys, InputFormat, PathsInput
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.process import Deserializer
Expand Down
8 changes: 2 additions & 6 deletions flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Sequence, TYPE_CHECKING
from typing import Callable, TYPE_CHECKING

import torch

from flash.core.data.callback import ControlFlow
from flash.core.data.utils import (
convert_to_modules,
CurrentFuncContext,
CurrentRunningStageContext,
)
from flash.core.data.utils import convert_to_modules, CurrentFuncContext, CurrentRunningStageContext
from flash.core.utilities.stages import RunningStage

if TYPE_CHECKING:
Expand Down
12 changes: 6 additions & 6 deletions flash/core/data/io/input_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch.utils.data._utils.collate import default_collate

from flash.core.data.callback import ControlFlow, FlashCallback
from flash.core.data.io.input import Input, InputFormat, DatasetInput
from flash.core.data.io.input import DatasetInput, Input, InputDataKeys, InputFormat
from flash.core.data.process import Deserializer
from flash.core.data.properties import ProcessState, Properties
from flash.core.data.states import (
Expand Down Expand Up @@ -268,9 +268,9 @@ def _check_transforms(
return transform

if isinstance(transform, list):
transform = {"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, torch.nn.Sequential(*transform))}
transform = {"pre_tensor_transform": ApplyToKeys(InputDataKeys.INPUT, torch.nn.Sequential(*transform))}
elif callable(transform):
transform = {"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, transform)}
transform = {"pre_tensor_transform": ApplyToKeys(InputDataKeys.INPUT, transform)}

if not isinstance(transform, Dict):
raise MisconfigurationException(
Expand Down Expand Up @@ -444,7 +444,7 @@ def collate(self, samples: Sequence, metadata=None) -> Any:
# return collate_fn.collate_fn(samples)

parameters = inspect.signature(collate_fn).parameters
if len(parameters) > 1 and DefaultDataKeys.METADATA in parameters:
if len(parameters) > 1 and InputDataKeys.METADATA in parameters:
return collate_fn(samples, metadata)
return collate_fn(samples)

Expand Down Expand Up @@ -655,7 +655,7 @@ def __init__(
def _extract_metadata(
samples: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
metadata = [s.pop(DefaultDataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples]
metadata = [s.pop(InputDataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples]
return samples, metadata if any(m is not None for m in metadata) else None

def forward(self, samples: Sequence[Any]) -> Any:
Expand Down Expand Up @@ -689,7 +689,7 @@ def forward(self, samples: Sequence[Any]) -> Any:
except TypeError:
samples = self.collate_fn(samples)
if metadata and isinstance(samples, dict):
samples[DefaultDataKeys.METADATA] = metadata
samples[InputDataKeys.METADATA] = metadata
self.callback.on_collate(samples, self.stage)

with self._per_batch_transform_context:
Expand Down
2 changes: 1 addition & 1 deletion flash/core/integrations/labelstudio/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pytorch_lightning.utilities.cloud_io import get_filesystem

from flash.core.data.auto_dataset import AutoDataset, IterableAutoDataset
from flash.core.data.io.input import Input, InputDataKeys, has_len
from flash.core.data.io.input import has_len, Input, InputDataKeys
from flash.core.data.utils import image_default_loader
from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE, _TEXT_AVAILABLE
from flash.core.utilities.stages import RunningStage
Expand Down
4 changes: 1 addition & 3 deletions flash/core/utilities/flash_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,7 @@ def add_arguments_to_parser(self, parser) -> None:
if isinstance(input, InputFormat):
input = input.value
if hasattr(self.local_datamodule_class, f"from_{input}"):
self.add_subcommand_from_function(
subcommands, getattr(self.local_datamodule_class, f"from_{input}")
)
self.add_subcommand_from_function(subcommands, getattr(self.local_datamodule_class, f"from_{input}"))

for datamodule_builder in self.additional_datamodule_builders:
self.add_subcommand_from_function(subcommands, datamodule_builder)
Expand Down
2 changes: 1 addition & 1 deletion flash/image/detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import InputDataKeys, InputFormat, FiftyOneInput
from flash.core.data.io.input import FiftyOneInput, InputDataKeys, InputFormat
from flash.core.data.io.input_transform import InputTransform
from flash.core.integrations.icevision.data import IceVisionParserInput, IceVisionPathsInput
from flash.core.integrations.icevision.transforms import default_transforms
Expand Down
2 changes: 1 addition & 1 deletion flash/image/segmentation/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch

import flash
from flash.core.data.io.input import InputDataKeys, ImageLabelsMap
from flash.core.data.io.input import ImageLabelsMap, InputDataKeys
from flash.core.data.io.output import Output
from flash.core.utilities.imports import (
_FIFTYONE_AVAILABLE,
Expand Down
10 changes: 5 additions & 5 deletions flash/tabular/regression/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch
from torch.nn import functional as F

from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.io.input import InputDataKeys
from flash.core.regression import RegressionTask
from flash.core.utilities.imports import _TABULAR_AVAILABLE
from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
Expand Down Expand Up @@ -90,19 +90,19 @@ def forward(self, x_in) -> torch.Tensor:
return self.model(x)[0].flatten()

def training_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
batch = (batch[InputDataKeys.INPUT], batch[InputDataKeys.TARGET])
return super().training_step(batch, batch_idx)

def validation_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
batch = (batch[InputDataKeys.INPUT], batch[InputDataKeys.TARGET])
return super().validation_step(batch, batch_idx)

def test_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
batch = (batch[InputDataKeys.INPUT], batch[InputDataKeys.TARGET])
return super().test_step(batch, batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
batch = batch[DefaultDataKeys.INPUT]
batch = batch[InputDataKeys.INPUT]
return self(batch)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion flash/text/seq2seq/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _initialize_model_specific_parameters(self):

@property
def tokenizer(self) -> "PreTrainedTokenizerBase":
return self.data_pipeline.data_source.tokenizer
return self.data_pipeline.input.tokenizer

def tokenize_labels(self, labels: Tensor) -> List[str]:
label_str = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
Expand Down
8 changes: 1 addition & 7 deletions flash/video/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,7 @@
from torch.utils.data import Sampler

from flash.core.data.data_module import DataModule
from flash.core.data.io.input import (
InputDataKeys,
InputFormat,
FiftyOneInput,
LabelsState,
PathsInput,
)
from flash.core.data.io.input import FiftyOneInput, InputDataKeys, InputFormat, LabelsState, PathsInput
from flash.core.data.io.input_transform import InputTransform
from flash.core.integrations.labelstudio.input import LabelStudioVideoClassificationInput
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _KORNIA_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, lazy_import
Expand Down
2 changes: 1 addition & 1 deletion flash_examples/flash_components/custom_data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from torch.utils.data._utils.collate import default_collate

from flash import _PACKAGE_ROOT, FlashDataset
from flash.core.data.io.input import InputDataKeys
from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.data.io.input import InputDataKeys
from flash.core.data.new_data_module import DataModule
from flash.core.data.transforms import ApplyToKeys
from flash.core.data.utils import download_data
Expand Down
2 changes: 1 addition & 1 deletion tests/graph/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_predict_dataset(tmpdir):
tudataset = datasets.TUDataset(root=tmpdir, name="KKI")
model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes)
data_pipe = DataPipeline(input_transform=GraphClassificationInputTransform())
out = model.predict(tudataset, data_source="datasets", data_pipeline=data_pipe)
out = model.predict(tudataset, input="datasets", data_pipeline=data_pipe)
assert isinstance(out[0], int)


Expand Down
4 changes: 2 additions & 2 deletions tests/image/segmentation/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_predict_tensor():
img = torch.rand(1, 3, 64, 64)
model = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
data_pipe = DataPipeline(input_transform=SemanticSegmentationInputTransform(num_classes=1))
out = model.predict(img, data_source="tensors", data_pipeline=data_pipe)
out = model.predict(img, input="tensors", data_pipeline=data_pipe)
assert isinstance(out[0], list)
assert len(out[0]) == 64
assert len(out[0][0]) == 64
Expand All @@ -118,7 +118,7 @@ def test_predict_numpy():
img = np.ones((1, 3, 64, 64))
model = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
data_pipe = DataPipeline(input_transform=SemanticSegmentationInputTransform(num_classes=1))
out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
out = model.predict(img, input="numpy", data_pipeline=data_pipe)
assert isinstance(out[0], list)
assert len(out[0]) == 64
assert len(out[0][0]) == 64
Expand Down

0 comments on commit 12c4f0f

Please sign in to comment.