diff --git a/rllib/utils/typing.py b/rllib/utils/typing.py index 81116fbcacaf..3f7f559b6e2c 100644 --- a/rllib/utils/typing.py +++ b/rllib/utils/typing.py @@ -39,8 +39,8 @@ jnp = jax.numpy # Represents a generic tensor type. -# This could be an np.ndarray, tf.Tensor, or a torch.Tensor. -TensorType = Union[np.array, "jnp.ndarray", "tf.Tensor", "torch.Tensor"] +# This could be an np.ndarray, jnp.ndarray, tf.Tensor, or a torch.Tensor. +TensorType = Union[np.ndarray, "jnp.ndarray", "tf.Tensor", "torch.Tensor"] # Either a plain tensor, or a dict or tuple of tensors (or StructTensors). TensorStructType = Union[TensorType, dict, tuple]