diff --git a/gymnasium/spaces/box.py b/gymnasium/spaces/box.py index 44217f5bb..db0d32d42 100644 --- a/gymnasium/spaces/box.py +++ b/gymnasium/spaces/box.py @@ -89,7 +89,7 @@ def __init__( if shape is not None: assert all( np.issubdtype(type(dim), np.integer) for dim in shape - ), f"Expect all shape elements to be an integer, actual type: {tuple(type(dim) for dim in shape)}" + ), f"Expected all shape elements to be an integer, actual type: {tuple(type(dim) for dim in shape)}" shape = tuple(int(dim) for dim in shape) # This changes any np types to int elif isinstance(low, np.ndarray): shape = low.shape @@ -99,7 +99,7 @@ def __init__( shape = (1,) else: raise ValueError( - f"Box shape is inferred from low and high, expect their types to be np.ndarray, an integer or a float, actual type low: {type(low)}, high: {type(high)}" + f"Box shape is inferred from low and high, expected their types to be np.ndarray, an integer or a float, actual type low: {type(low)}, high: {type(high)}" ) # Capture the boundedness information before replacing np.inf with get_inf @@ -109,8 +109,8 @@ def __init__( _high = np.full(shape, high, dtype=float) if is_float_integer(high) else high self.bounded_above: NDArray[np.bool_] = np.inf > _high - low = _broadcast(low, self.dtype, shape, inf_sign="-") - high = _broadcast(high, self.dtype, shape, inf_sign="+") + low = _broadcast(low, self.dtype, shape) + high = _broadcast(high, self.dtype, shape) assert isinstance(low, np.ndarray) assert ( @@ -121,6 +121,16 @@ def __init__( high.shape == shape ), f"high.shape doesn't match provided shape, high.shape: {high.shape}, shape: {shape}" + # check that we don't have invalid low or high + if np.any(low > high): + raise ValueError( + f"Some low values are greater than high, low={low}, high={high}" + ) + if np.any(np.isposinf(low)): + raise ValueError(f"No low value can be equal to `np.inf`, low={low}") + if np.any(np.isneginf(high)): + raise ValueError(f"No high value can be equal to `-np.inf`, high={high}") + self._shape: tuple[int, ...] = shape low_precision = get_precision(low.dtype) @@ -281,38 +291,6 @@ def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]): self.high_repr = _short_repr(self.high) -def get_inf(dtype: np.dtype, sign: str) -> int | float: - """Returns an infinite that doesn't break things. - - Args: - dtype: An `np.dtype` - sign (str): must be either `"+"` or `"-"` - - Returns: - Gets an infinite value with the sign and dtype - - Raises: - TypeError: Unknown sign, use either '+' or '-' - ValueError: Unknown dtype for infinite bounds - """ - if np.dtype(dtype).kind == "f": - if sign == "+": - return np.inf - elif sign == "-": - return -np.inf - else: - raise TypeError(f"Unknown sign {sign}, use either '+' or '-'") - elif np.dtype(dtype).kind == "i": - if sign == "+": - return np.iinfo(dtype).max - 2 - elif sign == "-": - return np.iinfo(dtype).min + 2 - else: - raise TypeError(f"Unknown sign {sign}, use either '+' or '-'") - else: - raise ValueError(f"Unknown dtype {dtype} for infinite bounds") - - def get_precision(dtype: np.dtype) -> SupportsFloat: """Get precision of a data type.""" if np.issubdtype(dtype, np.floating): @@ -325,17 +303,35 @@ def _broadcast( value: SupportsFloat | NDArray[Any], dtype: np.dtype, shape: tuple[int, ...], - inf_sign: str, ) -> NDArray[Any]: - """Handle infinite bounds and broadcast at the same time if needed.""" + """Handle infinite bounds and broadcast at the same time if needed. + + This is needed primarily because: + >>> import numpy as np + >>> np.full((2,), np.inf, dtype=np.int32) + array([-2147483648, -2147483648], dtype=int32) + """ if is_float_integer(value): - value = get_inf(dtype, inf_sign) if np.isinf(value) else value - value = np.full(shape, value, dtype=dtype) + if np.isneginf(value) and np.dtype(dtype).kind == "i": + value = np.iinfo(dtype).min + 2 + elif np.isposinf(value) and np.dtype(dtype).kind == "i": + value = np.iinfo(dtype).max - 2 + + return np.full(shape, value, dtype=dtype) + + elif isinstance(value, np.ndarray): + # this is needed because we can't stuff np.iinfo(int).min into an array of dtype float + casted_value = value.astype(dtype) + + # change bounds only if values are negative or positive infinite + if np.dtype(dtype).kind == "i": + casted_value[np.isneginf(value)] = np.iinfo(dtype).min + 2 + casted_value[np.isposinf(value)] = np.iinfo(dtype).max - 2 + + return casted_value + else: - assert isinstance(value, np.ndarray) - if np.any(np.isinf(value)): - # create new array with dtype, but maintain old one to preserve np.inf - temp = value.astype(dtype) - temp[np.isinf(value)] = get_inf(dtype, inf_sign) - value = temp - return value + # only np.ndarray allowed beyond this point + raise TypeError( + f"Unknown dtype for `value`, expected `np.ndarray` or float/integer, got {type(value)}" + ) diff --git a/gymnasium/utils/passive_env_checker.py b/gymnasium/utils/passive_env_checker.py index 402a24dc3..ff448536a 100644 --- a/gymnasium/utils/passive_env_checker.py +++ b/gymnasium/utils/passive_env_checker.py @@ -58,11 +58,6 @@ def _check_box_action_space(action_space: spaces.Box): "A Box action space maximum and minimum values are equal. " f"Actual equal coordinates: {[x for x in zip(*np.where(action_space.low == action_space.high))]}" ) - elif np.any(action_space.high < action_space.low): - logger.warn( - "A Box action space low value is greater than a high value. " - f"Actual less than coordinates: {[x for x in zip(*np.where(action_space.high < action_space.low))]}" - ) def check_space( diff --git a/tests/spaces/test_box.py b/tests/spaces/test_box.py index 24a07d2e8..85dab15ab 100644 --- a/tests/spaces/test_box.py +++ b/tests/spaces/test_box.py @@ -6,7 +6,6 @@ import gymnasium as gym from gymnasium.spaces import Box -from gymnasium.spaces.box import get_inf @pytest.mark.parametrize( @@ -61,7 +60,7 @@ def test_low_high_values(value, valid: bool): """Test what `low` and `high` values are valid for `Box` space.""" if valid: with warnings.catch_warnings(record=True) as caught_warnings: - Box(low=value, high=value) + Box(low=-np.inf, high=value) assert len(caught_warnings) == 0, tuple( warning.message for warning in caught_warnings ) @@ -69,10 +68,10 @@ def test_low_high_values(value, valid: bool): with pytest.raises( ValueError, match=re.escape( - "expect their types to be np.ndarray, an integer or a float" + "expected their types to be np.ndarray, an integer or a float" ), ): - Box(low=value, high=value) + Box(low=-np.inf, high=value) @pytest.mark.parametrize( @@ -90,7 +89,7 @@ def test_low_high_values(value, valid: bool): 1, {"shape": (None,)}, AssertionError, - "Expect all shape elements to be an integer, actual type: (,)", + "Expected all shape elements to be an integer, actual type: (,)", ), ( 0, @@ -102,7 +101,7 @@ def test_low_high_values(value, valid: bool): ) }, AssertionError, - "Expect all shape elements to be an integer, actual type: (, )", + "Expected all shape elements to be an integer, actual type: (, )", ), ( 0, @@ -114,21 +113,21 @@ def test_low_high_values(value, valid: bool): ) }, AssertionError, - "Expect all shape elements to be an integer, actual type: (, )", + "Expected all shape elements to be an integer, actual type: (, )", ), ( None, None, {}, ValueError, - "Box shape is inferred from low and high, expect their types to be np.ndarray, an integer or a float, actual type low: , high: ", + "Box shape is inferred from low and high, expected their types to be np.ndarray, an integer or a float, actual type low: , high: ", ), ( 0, None, {}, ValueError, - "Box shape is inferred from low and high, expect their types to be np.ndarray, an integer or a float, actual type low: , high: ", + "Box shape is inferred from low and high, expected their types to be np.ndarray, an integer or a float, actual type low: , high: ", ), ( np.zeros(3), @@ -283,29 +282,6 @@ def test_legacy_state_pickling(): assert b.high_repr == "1.0" -def test_get_inf(): - """Tests that get inf function works as expected, primarily for coverage.""" - assert get_inf(np.float32, "+") == np.inf - assert get_inf(np.float16, "-") == -np.inf - with pytest.raises( - TypeError, match=re.escape("Unknown sign *, use either '+' or '-'") - ): - get_inf(np.float32, "*") - - assert get_inf(np.int16, "+") == 32765 - assert get_inf(np.int8, "-") == -126 - with pytest.raises( - TypeError, match=re.escape("Unknown sign *, use either '+' or '-'") - ): - get_inf(np.int32, "*") - - with pytest.raises( - ValueError, - match=re.escape("Unknown dtype for infinite bounds"), - ): - get_inf(np.complex_, "+") - - def test_sample_mask(): """Box cannot have a mask applied.""" space = Box(0, 1) @@ -314,3 +290,56 @@ def test_sample_mask(): match=re.escape("Box.sample cannot be provided a mask, actual value: "), ): space.sample(mask=np.array([0, 1, 0], dtype=np.int8)) + + +@pytest.mark.parametrize( + "low, high, shape, dtype, reason", + [ + ( + 5.0, + 3.0, + (), + np.float32, + "Some low values are greater than high, low=5.0, high=3.0", + ), + ( + np.array([5.0, 6.0]), + np.array([1.0, 5.99]), + (2,), + np.float32, + "Some low values are greater than high, low=[5. 6.], high=[1. 5.99]", + ), + ( + np.inf, + np.inf, + (), + np.float32, + "No low value can be equal to `np.inf`, low=inf", + ), + ( + np.array([0, np.inf]), + np.array([np.inf, np.inf]), + (2,), + np.float32, + "No low value can be equal to `np.inf`, low=[ 0. inf]", + ), + ( + -np.inf, + -np.inf, + (), + np.float32, + "No high value can be equal to `-np.inf`, high=-inf", + ), + ( + np.array([-np.inf, -np.inf]), + np.array([0, -np.inf]), + (2,), + np.float32, + "No high value can be equal to `-np.inf`, high=[ 0. -inf]", + ), + ], +) +def test_invalid_low_high(low, high, dtype, shape, reason): + """Tests that we don't allow spaces with degenerate bounds, such as `Box(np.inf, -np.inf)`.""" + with pytest.raises(ValueError, match=re.escape(reason)): + Box(low=low, high=high, dtype=dtype, shape=shape) diff --git a/tests/utils/test_passive_env_checker.py b/tests/utils/test_passive_env_checker.py index c3685b1e8..d19f81ae5 100644 --- a/tests/utils/test_passive_env_checker.py +++ b/tests/utils/test_passive_env_checker.py @@ -37,11 +37,6 @@ def _modify_space(space: spaces.Space, attribute: str, value): spaces.Box(np.zeros(5), np.zeros(5)), "A Box observation space maximum and minimum values are equal. Actual equal coordinates: [(0,), (1,), (2,), (3,), (4,)]", ], - [ - UserWarning, - spaces.Box(np.ones(5), np.zeros(5)), - "A Box observation space low value is greater than a high value. Actual less than coordinates: [(0,), (1,), (2,), (3,), (4,)]", - ], [ AssertionError, _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "low", np.zeros(3)), @@ -113,11 +108,6 @@ def test_check_observation_space(test, space, message: str): spaces.Box(np.zeros(5), np.zeros(5)), "A Box action space maximum and minimum values are equal. Actual equal coordinates: [(0,), (1,), (2,), (3,), (4,)]", ], - [ - UserWarning, - spaces.Box(np.ones(5), np.zeros(5)), - "A Box action space low value is greater than a high value. Actual less than coordinates: [(0,), (1,), (2,), (3,), (4,)]", - ], [ AssertionError, _modify_space(spaces.Box(np.zeros(2), np.ones(2)), "low", np.zeros(3)),