[Feature] Contiguous stacking of matching specs (#960)
tcbegley authored Mar 13, 2023
1 parent a912a2e · commit c1acefd
Showing 2 changed files with 388 additions and 77 deletions.
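In short, after this change, stacking specs that match in type and shape no longer produces a lazy wrapper. A minimal sketch of the behaviour the updated tests check — the concrete bounds and shape are illustrative, and the `torchrl.data` import path is assumed to re-export the spec classes as in this test module; the assertions mirror `test_stack_bounded` below:

```python
import torch
from torchrl.data import BoundedTensorSpec

# Two identical specs: same class, same shape, same bounds (illustrative values).
c1 = BoundedTensorSpec(-1, 1, shape=(3,))
c2 = c1.clone()

# Matching specs now stack contiguously into a single BoundedTensorSpec with the
# stack dimension inserted into its shape, instead of a LazyStackedTensorSpec.
c = torch.stack([c1, c2], 0)
assert isinstance(c, BoundedTensorSpec)
assert c.shape == torch.Size([2, 3])
```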
106 changes: 59 additions & 47 deletions test/test_specs.py
@@ -18,7 +18,6 @@
CompositeSpec,
DiscreteTensorSpec,
LazyStackedCompositeSpec,
LazyStackedTensorSpec,
MultiDiscreteTensorSpec,
MultiOneHotDiscreteTensorSpec,
OneHotDiscreteTensorSpec,
@@ -1716,7 +1715,7 @@ def test_stack_binarydiscrete(self, shape, stack_dim):
c1 = BinaryDiscreteTensorSpec(n=n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
assert isinstance(c, LazyStackedTensorSpec)
assert isinstance(c, BinaryDiscreteTensorSpec)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -1761,7 +1760,7 @@ def test_stack_bounded(self, shape, stack_dim):
c1 = BoundedTensorSpec(mini, maxi, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
assert isinstance(c, LazyStackedTensorSpec)
assert isinstance(c, BoundedTensorSpec)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -1808,7 +1807,7 @@ def test_stack_discrete(self, shape, stack_dim):
c1 = DiscreteTensorSpec(n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
assert isinstance(c, LazyStackedTensorSpec)
assert isinstance(c, DiscreteTensorSpec)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -1852,7 +1851,7 @@ def test_stack_multidiscrete(self, shape, stack_dim):
c1 = MultiDiscreteTensorSpec(nvec, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
assert isinstance(c, LazyStackedTensorSpec)
assert isinstance(c, MultiDiscreteTensorSpec)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -1896,7 +1895,7 @@ def test_stack_multionehot(self, shape, stack_dim):
c1 = MultiOneHotDiscreteTensorSpec(nvec, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
assert isinstance(c, LazyStackedTensorSpec)
assert isinstance(c, MultiOneHotDiscreteTensorSpec)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -1940,7 +1939,7 @@ def test_stack_onehot(self, shape, stack_dim):
c1 = OneHotDiscreteTensorSpec(n, shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
assert isinstance(c, LazyStackedTensorSpec)
assert isinstance(c, OneHotDiscreteTensorSpec)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -1983,7 +1982,7 @@ def test_stack_unboundedcont(self, shape, stack_dim):
c1 = UnboundedContinuousTensorSpec(shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
assert isinstance(c, LazyStackedTensorSpec)
assert isinstance(c, UnboundedContinuousTensorSpec)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -2023,7 +2022,7 @@ def test_stack_unboundeddiscrete(self, shape, stack_dim):
c1 = UnboundedDiscreteTensorSpec(shape=shape)
c2 = c1.clone()
c = torch.stack([c1, c2], stack_dim)
assert isinstance(c, LazyStackedTensorSpec)
assert isinstance(c, UnboundedDiscreteTensorSpec)
shape = list(shape)
if stack_dim < 0:
stack_dim = len(shape) + stack_dim + 1
@@ -2064,11 +2063,13 @@ def test_stack(self):
c1 = CompositeSpec(a=UnboundedContinuousTensorSpec())
c2 = c1.clone()
c = torch.stack([c1, c2], 0)
assert isinstance(c, LazyStackedCompositeSpec)
assert isinstance(c, CompositeSpec)

def test_stack_index(self):
c1 = CompositeSpec(a=UnboundedContinuousTensorSpec())
c2 = c1.clone()
c2 = CompositeSpec(
a=UnboundedContinuousTensorSpec(), b=UnboundedDiscreteTensorSpec()
)
c = torch.stack([c1, c2], 0)
assert c.shape == torch.Size([2])
assert c[0] is c1
@@ -2082,7 +2083,11 @@ def test_stack_index(self):
@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_stack_index_multdim(self, stack_dim):
c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
c2 = c1.clone()
c2 = CompositeSpec(
a=UnboundedContinuousTensorSpec(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
c = torch.stack([c1, c2], stack_dim)
if stack_dim in (0, -3):
assert isinstance(c[:], LazyStackedCompositeSpec)
@@ -2146,36 +2151,14 @@ def test_stack_index_multdim(self, stack_dim):
assert c[:, :, 0, ...] is c1
assert c[:, :, 1, ...] is c2

@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_stack_expand_one(self, stack_dim):
c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
c = torch.stack([c1], stack_dim)
if stack_dim in (0, -3):
c_expand = c.expand([4, 2, 1, 3])
assert c_expand.shape == torch.Size([4, 2, 1, 3])
assert c_expand.dim == 1
elif stack_dim in (1, -2):
c_expand = c.expand([4, 1, 2, 3])
assert c_expand.shape == torch.Size([4, 1, 2, 3])
assert c_expand.dim == 2
elif stack_dim in (2, -1):
c_expand = c.expand(
[
4,
1,
3,
2,
]
)
assert c_expand.shape == torch.Size([4, 1, 3, 2])
assert c_expand.dim == 3
else:
raise NotImplementedError

@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_stack_expand_multi(self, stack_dim):
c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
c2 = c1.clone()
c2 = CompositeSpec(
a=UnboundedContinuousTensorSpec(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
c = torch.stack([c1, c2], stack_dim)
if stack_dim in (0, -3):
c_expand = c.expand([4, 2, 1, 3])
@@ -2202,7 +2185,11 @@ def test_stack_expand_multi(self, stack_dim):
@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_stack_rand(self, stack_dim):
c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
c2 = c1.clone()
c2 = CompositeSpec(
a=UnboundedContinuousTensorSpec(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
c = torch.stack([c1, c2], stack_dim)
r = c.rand()
assert isinstance(r, LazyStackedTensorDict)
@@ -2220,7 +2207,11 @@ def test_stack_rand(self, stack_dim):
@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_stack_rand_shape(self, stack_dim):
c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
c2 = c1.clone()
c2 = CompositeSpec(
a=UnboundedContinuousTensorSpec(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
c = torch.stack([c1, c2], stack_dim)
shape = [5, 6]
r = c.rand(shape)
@@ -2239,7 +2230,11 @@ def test_stack_rand_shape(self, stack_dim):
@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_stack_zero(self, stack_dim):
c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
c2 = c1.clone()
c2 = CompositeSpec(
a=UnboundedContinuousTensorSpec(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
c = torch.stack([c1, c2], stack_dim)
r = c.zero()
assert isinstance(r, LazyStackedTensorDict)
@@ -2257,7 +2252,11 @@ def test_stack_zero(self, stack_dim):
@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_stack_zero_shape(self, stack_dim):
c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
c2 = c1.clone()
c2 = CompositeSpec(
a=UnboundedContinuousTensorSpec(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
c = torch.stack([c1, c2], stack_dim)
shape = [5, 6]
r = c.zero(shape)
@@ -2274,18 +2273,31 @@ def test_stack_zero_shape(self, stack_dim):
assert (r["a"] == 0).all()

@pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda")
def test_to(self):
@pytest.mark.parametrize("stack_dim", [0, 1, 2, -3, -2, -1])
def test_to(self, stack_dim):
c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
c2 = c1.clone()
c2 = CompositeSpec(
a=UnboundedContinuousTensorSpec(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
c = torch.stack([c1, c2], stack_dim)
assert isinstance(c, LazyStackedCompositeSpec)
cdevice = c.to("cuda:0")
assert cdevice.device != c.device
assert cdevice.device == torch.device("cuda:0")
assert cdevice[0].device == torch.device("cuda:0")
if stack_dim < 0:
stack_dim += 3
index = (slice(None),) * stack_dim + (0,)
assert cdevice[index].device == torch.device("cuda:0")

def test_clone(self):
c1 = CompositeSpec(a=UnboundedContinuousTensorSpec(shape=(1, 3)), shape=(1, 3))
c2 = c1.clone()
c2 = CompositeSpec(
a=UnboundedContinuousTensorSpec(shape=(1, 3)),
b=UnboundedDiscreteTensorSpec(shape=(1, 3)),
shape=(1, 3),
)
c = torch.stack([c1, c2], 0)
cclone = c.clone()
assert cclone[0] is not c[0]
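For composite specs, the tests above reduce to the following sketch (same caveat about the `torchrl.data` import path; the assertions mirror `test_stack` and `test_stack_index`): composites with identical keys and shapes now stack into a plain `CompositeSpec`, while composites whose keys differ still fall back to a `LazyStackedCompositeSpec` whose elements index back to the original specs.

```python
import torch
from torchrl.data import (
    CompositeSpec,
    LazyStackedCompositeSpec,
    UnboundedContinuousTensorSpec,
    UnboundedDiscreteTensorSpec,
)

c1 = CompositeSpec(a=UnboundedContinuousTensorSpec())

# Identical composites: the stack is a regular CompositeSpec.
same = torch.stack([c1, c1.clone()], 0)
assert isinstance(same, CompositeSpec)

# Composites with different keys: the stack stays lazy, and indexing
# returns the original specs.
c2 = CompositeSpec(
    a=UnboundedContinuousTensorSpec(), b=UnboundedDiscreteTensorSpec()
)
mixed = torch.stack([c1, c2], 0)
assert isinstance(mixed, LazyStackedCompositeSpec)
assert mixed.shape == torch.Size([2])
assert mixed[0] is c1
assert mixed[1] is c2
```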
