Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix support for torch 11
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Mar 15, 2022
1 parent 6aebfbe commit 737cd33
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions flash/core/serve/types/bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ class BBox(BaseType):

def __post_init__(self):
self._valid_size = torch.Size([4])
self._invalid_types = {torch.bool, torch.complex32, torch.complex64, torch.complex128}

def _validate(self, elem):
if elem.shape != self._valid_size:
raise ValueError("Each box must consist of (only) four elements each " "corresponding to x1, x2, y1 and y2")
if elem.dtype in self._invalid_types:
if elem.dtype == torch.bool or torch.is_complex(elem):
raise TypeError(f"Found unsupported datatype for " f"bounding boxes: {elem.dtype}")

def deserialize(self, box: Tuple[float, ...]) -> torch.Tensor:
Expand Down

0 comments on commit 737cd33

Please sign in to comment.