Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Discrete(dtype) parameter #1197

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions gymnasium/spaces/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class Discrete(Space[np.int64]):
def __init__(
self,
n: int | np.integer[Any],
dtype: str | type[np.integer[Any]] = np.int64,
seed: int | np.random.Generator | None = None,
start: int | np.integer[Any] = 0,
):
Expand All @@ -36,6 +37,7 @@ def __init__(

Args:
n (int): The number of elements of this space.
dtype: This should be some kind of integer type.
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the ``Dict`` space.
start (int): The smallest element of this space.
"""
Expand All @@ -47,16 +49,29 @@ def __init__(
type(start), np.integer
), f"Expects `start` to be an integer, actual type: {type(start)}"

self.n = np.int64(n)
self.start = np.int64(start)
super().__init__((), np.int64, seed)
# determine dtype
if dtype is None:
raise ValueError(
"Discrete dtype must be explicitly provided, cannot be None."
)
self.dtype = np.dtype(dtype)

# * check that dtype is an accepted dtype
if not (np.issubdtype(self.dtype, np.integer)):
raise ValueError(
f"Invalid Discrete dtype ({self.dtype}), must be an integer dtype"
)

self.n = self.dtype.type(n)
self.start = self.dtype.type(start)
super().__init__((), self.dtype, seed)

@property
def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True

def sample(self, mask: MaskNDArray | None = None) -> np.int64:
def sample(self, mask: MaskNDArray | None = None) -> np.integer[Any]:
"""Generates a single random sample from this space.

A sample will be chosen uniformly at random with the mask if provided
Expand Down Expand Up @@ -84,13 +99,13 @@ def sample(self, mask: MaskNDArray | None = None) -> np.int64:
np.logical_or(mask == 0, valid_action_mask)
), f"All values of a mask should be 0 or 1, actual values: {mask}"
if np.any(valid_action_mask):
return self.start + self.np_random.choice(
JDRanpariya marked this conversation as resolved.
Show resolved Hide resolved
np.where(valid_action_mask)[0]
return self.start + self.dtype.type(
self.np_random.choice(np.where(valid_action_mask)[0])
)
else:
return self.start

return self.start + self.np_random.integers(self.n)
return self.start + self.dtype.type(self.np_random.integers(self.n))
JDRanpariya marked this conversation as resolved.
Show resolved Hide resolved

def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
Expand Down Expand Up @@ -137,7 +152,7 @@ def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):

super().__setstate__(state)

def to_jsonable(self, sample_n: Sequence[np.int64]) -> list[int]:
def to_jsonable(self, sample_n: Sequence[np.integer[Any]]) -> list[int]:
"""Converts a list of samples to a list of ints."""
return [int(x) for x in sample_n]

Expand Down
Loading