Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix support for torch 11 (#1234)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Mar 30, 2022
1 parent b9e1af1 commit eb4c990
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where buffers in loss functions were not correctly registered in the `Task` ([#1203](https://github.com/PyTorchLightning/lightning-flash/pull/1203))
- Fixed support for passing a sampler instance to `from_*` methods / the `DataModule` ([#1204](https://github.com/PyTorchLightning/lightning-flash/pull/1204))

- Fixed support for `torch==1.11.0` ([#1234](https://github.com/PyTorchLightning/lightning-flash/pull/1234))

## [0.7.0] - 2022-02-15

### Added
Expand Down
3 changes: 1 addition & 2 deletions flash/core/serve/types/bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ class BBox(BaseType):

def __post_init__(self):
self._valid_size = torch.Size([4])
self._invalid_types = {torch.bool, torch.complex32, torch.complex64, torch.complex128}

def _validate(self, elem):
if elem.shape != self._valid_size:
raise ValueError("Each box must consist of (only) four elements each " "corresponding to x1, x2, y1 and y2")
if elem.dtype in self._invalid_types:
if elem.dtype == torch.bool or torch.is_complex(elem):
raise TypeError(f"Found unsupported datatype for " f"bounding boxes: {elem.dtype}")

def deserialize(self, box: Tuple[float, ...]) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion flash/text/question_answering/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def load_data(

if flash._IS_TESTING:
# NOTE: must subset in this way to return a Dataset
hf_dataset = [sample for sample in hf_dataset.select(range(40), keep_in_memory=True)]
hf_dataset = [sample for sample in hf_dataset][:40]

return hf_dataset

Expand Down
2 changes: 1 addition & 1 deletion flash/text/seq2seq/core/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def load_data(

if flash._IS_TESTING:
# NOTE: must subset in this way to return a Dataset
hf_dataset = [sample for sample in hf_dataset.select(range(40), keep_in_memory=True)]
hf_dataset = [sample for sample in hf_dataset][:40]

return hf_dataset

Expand Down

0 comments on commit eb4c990

Please sign in to comment.