From b6e54e874af379e3d93694853cb24602b2a21660 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 15 Mar 2022 17:27:09 +0000 Subject: [PATCH] Fix support for torch 11 (#1234) --- CHANGELOG.md | 2 ++ flash/core/serve/types/bbox.py | 3 +-- flash/text/question_answering/input.py | 2 +- flash/text/seq2seq/core/input.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 408e7be5da..8686112abd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where buffers in loss functions were not correctly registered in the `Task` ([#1203](https://github.com/PyTorchLightning/lightning-flash/pull/1203)) - Fixed support for passing a sampler instance to `from_*` methods / the `DataModule` ([#1204](https://github.com/PyTorchLightning/lightning-flash/pull/1204)) +- Fixed support for `torch==1.11.0` ([#1234](https://github.com/PyTorchLightning/lightning-flash/pull/1234)) + ## [0.7.0] - 2022-02-15 ### Added diff --git a/flash/core/serve/types/bbox.py b/flash/core/serve/types/bbox.py index 5f8b01951b..73881f0127 100644 --- a/flash/core/serve/types/bbox.py +++ b/flash/core/serve/types/bbox.py @@ -20,12 +20,11 @@ class BBox(BaseType): def __post_init__(self): self._valid_size = torch.Size([4]) - self._invalid_types = {torch.bool, torch.complex32, torch.complex64, torch.complex128} def _validate(self, elem): if elem.shape != self._valid_size: raise ValueError("Each box must consist of (only) four elements each " "corresponding to x1, x2, y1 and y2") - if elem.dtype in self._invalid_types: + if elem.dtype == torch.bool or torch.is_complex(elem): raise TypeError(f"Found unsupported datatype for " f"bounding boxes: {elem.dtype}") def deserialize(self, box: Tuple[float, ...]) -> torch.Tensor: diff --git a/flash/text/question_answering/input.py b/flash/text/question_answering/input.py index 541c451e31..b1d18c371d 100644 --- a/flash/text/question_answering/input.py +++ b/flash/text/question_answering/input.py @@ -77,7 +77,7 @@ def load_data( if flash._IS_TESTING: # NOTE: must subset in this way to return a Dataset - hf_dataset = [sample for sample in hf_dataset.select(range(40), keep_in_memory=True)] + hf_dataset = [sample for sample in hf_dataset][:40] return hf_dataset diff --git a/flash/text/seq2seq/core/input.py b/flash/text/seq2seq/core/input.py index 0f935d49aa..88d14adde6 100644 --- a/flash/text/seq2seq/core/input.py +++ b/flash/text/seq2seq/core/input.py @@ -44,7 +44,7 @@ def load_data( if flash._IS_TESTING: # NOTE: must subset in this way to return a Dataset - hf_dataset = [sample for sample in hf_dataset.select(range(40), keep_in_memory=True)] + hf_dataset = [sample for sample in hf_dataset][:40] return hf_dataset