diff --git a/gymnasium/spaces/sequence.py b/gymnasium/spaces/sequence.py index 9c19942f6..72be364e1 100644 --- a/gymnasium/spaces/sequence.py +++ b/gymnasium/spaces/sequence.py @@ -119,7 +119,8 @@ def sample( def contains(self, x: Any) -> bool: """Return boolean specifying if x is a valid member of this space.""" - return isinstance(x, collections.abc.Sequence) and all( + # by definition, any sequence is an iterable + return isinstance(x, collections.abc.Iterable) and all( self.feature_space.contains(item) for item in x ) diff --git a/tests/spaces/test_sequence.py b/tests/spaces/test_sequence.py index 110cd6bdc..cb29c5b51 100644 --- a/tests/spaces/test_sequence.py +++ b/tests/spaces/test_sequence.py @@ -6,6 +6,15 @@ import gymnasium as gym +def test_stacked_box(): + """Tests that sequence with a feature space of Box allows stacked np arrays.""" + space = gym.spaces.Sequence(gym.spaces.Box(0, 1, shape=(3,))) + sample = np.float32(np.random.rand(5, 3)) + assert space.contains( + sample + ), "Something went wrong, should be able to accept stacked np arrays for Box feature space." + + def test_sample(): """Tests the sequence sampling works as expects and the errors are correctly raised.""" space = gym.spaces.Sequence(gym.spaces.Box(0, 1))