[Refactor] Binary spec inherits from discrete spec #984

Merged (2 commits) on Mar 23, 2023
1 change: 1 addition & 0 deletions test/test_specs.py
@@ -2061,6 +2061,7 @@ def test_to_numpy(self, shape, stack_dim):
c1 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float64)
c2 = BoundedTensorSpec(-1, 1, shape=shape, dtype=torch.float32)
c = torch.stack([c1, c2], stack_dim)
torch.manual_seed(0)

shape = list(shape)
shape.insert(stack_dim, 2)
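The one-line test change above seeds torch's global RNG before the test samples from the stacked spec, making the drawn values reproducible. A minimal sketch of the pattern (not part of the diff, assuming the spec's rand() draws from the global generator):

import torch

torch.manual_seed(0)
a = torch.rand(3)          # first draw from the seeded generator
torch.manual_seed(0)
b = torch.rand(3)          # re-seeding replays the exact same draw
assert torch.equal(a, b)   # sampling is now deterministic across runs
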
207 changes: 86 additions & 121 deletions torchrl/data/tensor_specs.py
@@ -1357,127 +1357,6 @@ def expand(self, *shape):
return self.__class__(shape=shape, device=self.device, dtype=self.dtype)


@dataclass(repr=False)
class BinaryDiscreteTensorSpec(TensorSpec):
"""A binary discrete tensor spec.

Args:
n (int): length of the binary vector.
shape (torch.Size, optional): total shape of the sampled tensors.
If provided, the last dimension must match n.
device (str, int or torch.device, optional): device of the tensors.
dtype (str or torch.dtype, optional): dtype of the tensors. Defaults to torch.long.

Examples:
>>> spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), device="cpu", dtype=torch.bool)
>>> print(spec.zero())
"""

shape: torch.Size
space: BinaryBox
device: torch.device = torch.device("cpu")
dtype: torch.dtype = torch.float
domain: str = ""

# SPEC_HANDLED_FUNCTIONS = {}

def __init__(
self,
n: int,
shape: Optional[torch.Size] = None,
device: Optional[DEVICE_TYPING] = None,
dtype: Union[str, torch.dtype] = torch.long,
):
dtype, device = _default_dtype_and_device(dtype, device)
box = BinaryBox(n)
if shape is None or not len(shape):
shape = torch.Size((n,))
else:
shape = torch.Size(shape)
if shape[-1] != box.n:
raise ValueError(
f"The last value of the shape must match n for transform of type {self.__class__}. "
f"Got n={box.n} and shape={shape}."
)

super().__init__(shape, box, device, dtype, domain="discrete")

def rand(self, shape=None) -> torch.Tensor:
if shape is None:
shape = torch.Size([])
shape = [*shape, *self.shape]
return torch.zeros(shape, device=self.device, dtype=self.dtype).bernoulli_()

def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
if not isinstance(index, torch.Tensor):
raise ValueError(
f"Only tensors are allowed for indexing using"
f" {self.__class__.__name__}.index(...)"
)
index = index.nonzero().squeeze()
index = index.expand((*tensor_to_index.shape[:-1], index.shape[-1]))
return tensor_to_index.gather(-1, index)

def is_in(self, val: torch.Tensor) -> bool:
return ((val == 0) | (val == 1)).all()

def expand(self, *shape):
if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
shape = shape[0]
if any(val < 0 for val in shape):
raise ValueError(
f"{self.__class__.__name__}.extend does not support negative shapes."
)
if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
raise ValueError(
f"The last {self.ndim} of the extended shape must match the"
f"shape of the CompositeSpec in CompositeSpec.extend."
)
return self.__class__(
n=shape[-1], shape=shape, device=self.device, dtype=self.dtype
)

def squeeze(self, dim: int | None = None):
if self.shape[-1] == 1 and dim in (len(self.shape), -1, None):
raise ValueError(
"Final dimension of BinaryDiscreteTensorSpec must remain unchanged"
)
shape = _squeezed_shape(self.shape, dim)
if shape is None:
return self
return self.__class__(
n=shape[-1], shape=shape, device=self.device, dtype=self.dtype
)

def unsqueeze(self, dim: int):
if dim in (len(self.shape), -1):
raise ValueError(
"Final dimension of BinaryDiscreteTensorSpec must remain unchanged"
)
shape = _unsqueezed_shape(self.shape, dim)
return self.__class__(
n=shape[-1], shape=shape, device=self.device, dtype=self.dtype
)

def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
if isinstance(dest, torch.dtype):
dest_dtype = dest
dest_device = self.device
else:
dest_dtype = self.dtype
dest_device = torch.device(dest)
if dest_device == self.device and dest_dtype == self.dtype:
return self
return self.__class__(
n=self.space.n, shape=self.shape, device=dest_device, dtype=dest_dtype
)

def clone(self) -> CompositeSpec:
return self.__class__(
n=self.space.n, shape=self.shape, device=self.device, dtype=self.dtype
)


@dataclass(repr=False)
class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec):
"""A concatenation of one-hot discrete tensor spec.
@@ -1844,6 +1723,92 @@ def clone(self) -> CompositeSpec:
)


@dataclass(repr=False)
class BinaryDiscreteTensorSpec(DiscreteTensorSpec):
"""A binary discrete tensor spec.

Args:
n (int): length of the binary vector.
shape (torch.Size, optional): total shape of the sampled tensors.
If provided, the last dimension must match n.
device (str, int or torch.device, optional): device of the tensors.
dtype (str or torch.dtype, optional): dtype of the tensors. Defaults to torch.long.

Examples:
>>> spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), device="cpu", dtype=torch.bool)
>>> print(spec.zero())
"""

def __init__(
self,
n: int,
shape: Optional[torch.Size] = None,
device: Optional[DEVICE_TYPING] = None,
dtype: Union[str, torch.dtype] = torch.long,
):
if shape is None or not len(shape):
shape = torch.Size((n,))
else:
shape = torch.Size(shape)
if shape[-1] != n:
raise ValueError(
f"The last value of the shape must match n for spec {self.__class__}. "
f"Got n={n} and shape={shape}."
)
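        # The refactor's core idea: each entry of the binary vector is a
        # discrete value with two categories, so delegate to DiscreteTensorSpec
        # with n=2; the vector length n lives in the trailing shape dimension.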
super().__init__(n=2, shape=shape, device=device, dtype=dtype)

def expand(self, *shape):
if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
shape = shape[0]
if any(val < 0 for val in shape):
raise ValueError(
f"{self.__class__.__name__}.extend does not support negative shapes."
)
if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
raise ValueError(
f"The last {self.ndim} of the extended shape must match the"
f"shape of the CompositeSpec in CompositeSpec.extend."
)
return self.__class__(
n=shape[-1], shape=shape, device=self.device, dtype=self.dtype
)

def squeeze(self, dim=None):
shape = _squeezed_shape(self.shape, dim)
if shape is None:
return self
return self.__class__(
n=shape[-1], shape=shape, device=self.device, dtype=self.dtype
)

def unsqueeze(self, dim: int):
shape = _unsqueezed_shape(self.shape, dim)
return self.__class__(
n=shape[-1], shape=shape, device=self.device, dtype=self.dtype
)

def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec:
if isinstance(dest, torch.dtype):
dest_dtype = dest
dest_device = self.device
else:
dest_dtype = self.dtype
dest_device = torch.device(dest)
if dest_device == self.device and dest_dtype == self.dtype:
return self
return self.__class__(
n=self.shape[-1], shape=self.shape, device=dest_device, dtype=dest_dtype
)

def clone(self) -> CompositeSpec:
return self.__class__(
n=self.shape[-1],
shape=self.shape,
device=self.device,
dtype=self.dtype,
)
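
For context (not part of the diff), a hedged usage sketch of the reparented class: with DiscreteTensorSpec as the base and n=2 categories per entry, BinaryDiscreteTensorSpec inherits rand, index and is_in instead of reimplementing them, and sampling already yields {0, 1} values:

import torch
from torchrl.data.tensor_specs import BinaryDiscreteTensorSpec

spec = BinaryDiscreteTensorSpec(n=4, shape=(5, 4), dtype=torch.long)
sample = spec.rand()             # inherited DiscreteTensorSpec.rand: entries in {0, 1}
assert sample.shape == torch.Size([5, 4])
assert spec.is_in(sample)        # inherited membership check (0 <= val < 2)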


@dataclass(repr=False)
class MultiDiscreteTensorSpec(DiscreteTensorSpec):
"""A concatenation of discrete tensor spec.