Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into issue_191-more_return_types
Browse files Browse the repository at this point in the history
  • Loading branch information
aniketmaurya authored Mar 29, 2021
2 parents c6c2288 + 3b4c6b6 commit 0109236
Show file tree
Hide file tree
Showing 13 changed files with 52 additions and 45 deletions.
3 changes: 2 additions & 1 deletion flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
from typing import Any, Union

import torch
from torch import Tensor

from flash.core.data import TaskDataPipeline
from flash.core.model import Task


class ClassificationDataPipeline(TaskDataPipeline):

def before_uncollate(self, batch: Union[torch.Tensor, tuple]) -> torch.Tensor:
def before_uncollate(self, batch: Union[Tensor, tuple]) -> Tensor:
if isinstance(batch, tuple):
batch = batch[0]
return torch.softmax(batch, -1)
Expand Down
3 changes: 2 additions & 1 deletion flash/core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import requests
import torch
from torch import Tensor
from tqdm.auto import tqdm as tq


Expand Down Expand Up @@ -81,7 +82,7 @@ def download_data(url: str, path: str = "data/") -> None:
download_file(url, path)


def _contains_any_tensor(value: Any, dtype: Type = torch.Tensor) -> bool:
def _contains_any_tensor(value: Any, dtype: Type = Tensor) -> bool:
# TODO: we should refactor FlashDatasetFolder to better integrate
# with DataPipeline. That way, we wouldn't need this check.
# This is because we are running transforms in both places.
Expand Down
3 changes: 2 additions & 1 deletion flash/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import torch
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor

from flash.data.utils import _contains_any_tensor, convert_to_modules

Expand Down Expand Up @@ -178,7 +179,7 @@ def default_uncollate(batch: Any):

batch_type = type(batch)

if isinstance(batch, torch.Tensor):
if isinstance(batch, Tensor):
return list(torch.unbind(batch, 0))

elif isinstance(batch, Mapping):
Expand Down
5 changes: 3 additions & 2 deletions flash/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import torch
from pytorch_lightning.trainer.states import RunningStage
from torch import Tensor
from torch.nn import Module
from torch.utils.data._utils.collate import default_collate

Expand Down Expand Up @@ -101,10 +102,10 @@ def load_sample(cls, sample: Any, dataset: Optional[Any] = None) -> Any:
def pre_tensor_transform(self, sample: Any) -> Any:
return sample

def to_tensor_transform(self, sample: Any) -> torch.Tensor:
def to_tensor_transform(self, sample: Any) -> Tensor:
return sample

def post_tensor_transform(self, sample: torch.Tensor) -> torch.Tensor:
def post_tensor_transform(self, sample: Tensor) -> Tensor:
return sample

def per_batch_transform(self, batch: Any) -> Any:
Expand Down
3 changes: 2 additions & 1 deletion flash/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import torch
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.apply_func import apply_to_collection
from torch import Tensor
from tqdm.auto import tqdm as tq

_STAGES_PREFIX = {
Expand Down Expand Up @@ -69,7 +70,7 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
zip_ref.extractall(path)


def _contains_any_tensor(value: Any, dtype: Type = torch.Tensor) -> bool:
def _contains_any_tensor(value: Any, dtype: Type = Tensor) -> bool:
# TODO: we should refactor FlashDatasetFolder to better integrate
# with DataPipeline. That way, we wouldn't need this check.
# This is because we are running transforms in both places.
Expand Down
3 changes: 1 addition & 2 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ def after_collate(self, batch: Tensor) -> Tensor:
batch["input_ids"] = batch["input_ids"].squeeze(0)
return batch

def before_uncollate(self, batch: Union[torch.Tensor, tuple,
SequenceClassifierOutput]) -> Union[tuple, torch.Tensor]:
def before_uncollate(self, batch: Union[Tensor, tuple, SequenceClassifierOutput]) -> Union[tuple, Tensor]:
if isinstance(batch, SequenceClassifierOutput):
batch = batch.logits
return super().before_uncollate(batch)
Expand Down
7 changes: 4 additions & 3 deletions flash/text/seq2seq/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities import rank_zero_info
from torch import Tensor
from transformers import AutoModelForSeq2SeqLM, PreTrainedTokenizerBase

from flash.core import Task
Expand Down Expand Up @@ -83,13 +84,13 @@ def forward(self, x: Any) -> Any:
)
return generated_tokens

def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
def training_step(self, batch: Any, batch_idx: int) -> Tensor:
outputs = self.model(**batch)
loss = outputs[0]
self.log("train_loss", loss)
return loss

def common_step(self, prefix: str, batch: Any) -> torch.Tensor:
def common_step(self, prefix: str, batch: Any) -> Tensor:
generated_tokens = self.predict(batch, skip_collate_fn=True)
self.compute_metrics(generated_tokens, batch, prefix)

Expand Down Expand Up @@ -121,7 +122,7 @@ def _initialize_model_specific_parameters(self):
def tokenizer(self) -> PreTrainedTokenizerBase:
return self.data_pipeline.tokenizer

def tokenize_labels(self, labels: torch.Tensor) -> List[str]:
def tokenize_labels(self, labels: Tensor) -> List[str]:
label_str = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
return [str.strip(s) for s in label_str]

Expand Down
4 changes: 2 additions & 2 deletions flash/text/seq2seq/summarization/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from typing import Dict, List, Tuple

import numpy as np
import torch
from rouge_score import rouge_scorer, scoring
from rouge_score.scoring import AggregateScore, Score
from torch import tensor
from torchmetrics import Metric

from flash.text.seq2seq.summarization.utils import add_newline_to_end_of_each_sentence
Expand Down Expand Up @@ -71,7 +71,7 @@ def update(self, pred_lns: List[str], tgt_lns: List[str]):
tgt = add_newline_to_end_of_each_sentence(tgt)
results = self.scorer.score(pred, tgt)
for key, score in results.items():
score = torch.tensor([score.precision, score.recall, score.fmeasure])
score = tensor([score.precision, score.recall, score.fmeasure])
getattr(self, key).append(score)

def compute(self) -> Dict[str, float]:
Expand Down
13 changes: 7 additions & 6 deletions flash/text/seq2seq/translation/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import List

import torch
from torch import tensor
from torchmetrics import Metric


Expand Down Expand Up @@ -66,8 +67,8 @@ def __init__(self, n_gram: int = 4, smooth: bool = False):
self.n_gram = n_gram
self.smooth = smooth

self.add_state("c", torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum")
self.add_state("r", torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum")
self.add_state("c", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
self.add_state("r", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum")
self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum")

Expand All @@ -77,7 +78,7 @@ def compute(self):
ref_len = self.r.clone().detach()

if min(self.numerator) == 0.0:
return torch.tensor(0.0, device=self.r.device)
return tensor(0.0, device=self.r.device)

if self.smooth:
precision_scores = torch.add(self.numerator, torch.ones(
Expand All @@ -86,11 +87,11 @@ def compute(self):
else:
precision_scores = self.numerator / self.denominator

log_precision_scores = torch.tensor([1.0 / self.n_gram] * self.n_gram,
device=self.r.device) * torch.log(precision_scores)
log_precision_scores = tensor([1.0 / self.n_gram] * self.n_gram,
device=self.r.device) * torch.log(precision_scores)
geometric_mean = torch.exp(torch.sum(log_precision_scores))
brevity_penalty = (
torch.tensor(1.0, device=self.r.device) if self.c > self.r else torch.exp(1 - (ref_len / trans_len))
tensor(1.0, device=self.r.device) if self.c > self.r else torch.exp(1 - (ref_len / trans_len))
)
bleu = brevity_penalty * geometric_mean
return bleu
Expand Down
4 changes: 2 additions & 2 deletions flash/vision/detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from PIL import Image
from pytorch_lightning.utilities import _module_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor
from torch import Tensor, tensor
from torch._six import container_abcs
from torch.utils.data._utils.collate import default_collate
from torchvision import transforms as T
Expand Down Expand Up @@ -88,7 +88,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
target = {}
target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
target["image_id"] = torch.tensor([img_idx])
target["image_id"] = tensor([img_idx])
target["area"] = torch.as_tensor(areas, dtype=torch.float32)
target["iscrowd"] = torch.as_tensor(iscrowd, dtype=torch.int64)

Expand Down
4 changes: 2 additions & 2 deletions flash/vision/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import torch
import torchvision
from torch import nn
from torch import nn, tensor
from torch.optim import Optimizer
from torchvision.models.detection.faster_rcnn import FasterRCNN, FastRCNNPredictor
from torchvision.models.detection.retinanet import RetinaNet, RetinaNetHead
Expand All @@ -39,7 +39,7 @@ def _evaluate_iou(target, pred):
"""
if pred["boxes"].shape[0] == 0:
# no box detected, 0 IOU
return torch.tensor(0.0, device=pred["boxes"].device)
return tensor(0.0, device=pred["boxes"].device)
return box_iou(target["boxes"], pred["boxes"]).diag().mean()


Expand Down
6 changes: 3 additions & 3 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytorch_lightning as pl
import torch
from PIL import Image
from torch import nn
from torch import nn, Tensor
from torch.nn import functional as F

from flash import ClassificationTask
Expand All @@ -33,7 +33,7 @@

class DummyDataset(torch.utils.data.Dataset):

def __getitem__(self, index: int) -> Tuple[torch.Tensor, Number]:
def __getitem__(self, index: int) -> Tuple[Tensor, Number]:
return torch.rand(1, 28, 28), torch.randint(10, size=(1, )).item()

def __len__(self) -> int:
Expand All @@ -42,7 +42,7 @@ def __len__(self) -> int:

class PredictDummyDataset(DummyDataset):

def __getitem__(self, index: int) -> torch.Tensor:
def __getitem__(self, index: int) -> Tensor:
return torch.rand(1, 28, 28)


Expand Down
39 changes: 20 additions & 19 deletions tests/data/test_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor, tensor
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate

Expand All @@ -36,7 +37,7 @@

class DummyDataset(torch.utils.data.Dataset):

def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
return torch.rand(1), torch.rand(1)

def __len__(self) -> int:
Expand Down Expand Up @@ -507,49 +508,49 @@ def train_pre_tensor_transform(self, sample: Any) -> Any:
self.train_pre_tensor_transform_called = True
return sample + (5, )

def train_collate(self, samples) -> torch.Tensor:
def train_collate(self, samples) -> Tensor:
self.train_collate_called = True
return torch.tensor([list(s) for s in samples])
return tensor([list(s) for s in samples])

def train_per_batch_transform_on_device(self, batch: Any) -> Any:
self.train_per_batch_transform_on_device_called = True
assert torch.equal(batch, torch.tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]]))
assert torch.equal(batch, tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]]))

def val_load_data(self, sample, dataset) -> List[int]:
self.val_load_data_called = True
assert isinstance(dataset, AutoDataset)
return list(range(5))

def val_load_sample(self, sample) -> Dict[str, torch.Tensor]:
def val_load_sample(self, sample) -> Dict[str, Tensor]:
self.val_load_sample_called = True
return {"a": sample, "b": sample + 1}

def val_to_tensor_transform(self, sample: Any) -> torch.Tensor:
def val_to_tensor_transform(self, sample: Any) -> Tensor:
self.val_to_tensor_transform_called = True
return sample

def val_collate(self, samples) -> Dict[str, torch.Tensor]:
def val_collate(self, samples) -> Dict[str, Tensor]:
self.val_collate_called = True
_count = samples[0]['a']
assert samples == [{'a': _count, 'b': _count + 1}, {'a': _count + 1, 'b': _count + 2}]
return {'a': torch.tensor([0, 1]), 'b': torch.tensor([1, 2])}
return {'a': tensor([0, 1]), 'b': tensor([1, 2])}

def val_per_batch_transform_on_device(self, batch: Any) -> Any:
self.val_per_batch_transform_on_device_called = True
batch = batch[0]
assert torch.equal(batch["a"], torch.tensor([0, 1]))
assert torch.equal(batch["b"], torch.tensor([1, 2]))
assert torch.equal(batch["a"], tensor([0, 1]))
assert torch.equal(batch["b"], tensor([1, 2]))
return [False]

def test_load_data(self, sample) -> LamdaDummyDataset:
self.test_load_data_called = True
return LamdaDummyDataset(lambda: [torch.rand(1), torch.rand(1)])

def test_to_tensor_transform(self, sample: Any) -> torch.Tensor:
def test_to_tensor_transform(self, sample: Any) -> Tensor:
self.test_to_tensor_transform_called = True
return sample

def test_post_tensor_transform(self, sample: torch.Tensor) -> torch.Tensor:
def test_post_tensor_transform(self, sample: Tensor) -> Tensor:
self.test_post_tensor_transform_called = True
return sample

Expand All @@ -560,9 +561,9 @@ def predict_load_data(self, sample) -> LamdaDummyDataset:

class TestPreprocessTransformations2(TestPreprocessTransformations):

def val_to_tensor_transform(self, sample: Any) -> torch.Tensor:
def val_to_tensor_transform(self, sample: Any) -> Tensor:
self.val_to_tensor_transform_called = True
return {"a": torch.tensor(sample["a"]), "b": torch.tensor(sample["b"])}
return {"a": tensor(sample["a"]), "b": tensor(sample["b"])}


@pytest.mark.skipif(reason="Still using DataPipeline Old API")
Expand All @@ -585,7 +586,7 @@ def test_step(self, batch, batch_idx):

def predict_step(self, batch, batch_idx, dataloader_idx):
assert batch == [('a', 'a'), ('b', 'b')]
return torch.tensor([0, 0, 0])
return tensor([0, 0, 0])

class CustomDataModule(DataModule):

Expand All @@ -595,7 +596,7 @@ class CustomDataModule(DataModule):

assert datamodule.train_dataloader().dataset[0] == (0, 1, 2, 3)
batch = next(iter(datamodule.train_dataloader()))
assert torch.equal(batch, torch.tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]]))
assert torch.equal(batch, tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]]))

assert datamodule.val_dataloader().dataset[0] == {'a': 0, 'b': 1}
assert datamodule.val_dataloader().dataset[1] == {'a': 1, 'b': 2}
Expand All @@ -605,8 +606,8 @@ class CustomDataModule(DataModule):
CustomDataModule.preprocess_cls = TestPreprocessTransformations2
datamodule = CustomDataModule.from_load_data_inputs(1, 1, 1, 1, batch_size=2)
batch = next(iter(datamodule.val_dataloader()))
assert torch.equal(batch["a"], torch.tensor([0, 1]))
assert torch.equal(batch["b"], torch.tensor([1, 2]))
assert torch.equal(batch["a"], tensor([0, 1]))
assert torch.equal(batch["b"], tensor([1, 2]))

model = CustomModel()
trainer = Trainer(
Expand Down Expand Up @@ -679,7 +680,7 @@ def load_sample(self, path: str) -> Image.Image:
img8Bit = np.uint8(np.random.uniform(0, 1, (64, 64, 3)) * 255.0)
return Image.fromarray(img8Bit)

def to_tensor_transform(self, pil_image: Image.Image) -> torch.Tensor:
def to_tensor_transform(self, pil_image: Image.Image) -> Tensor:
# convert pil image into a tensor
return self._to_tensor(pil_image)

Expand Down

0 comments on commit 0109236

Please sign in to comment.