Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Discrete(dtype) parameter #1197

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions gymnasium/spaces/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class Discrete(Space[np.int64]):
def __init__(
self,
n: int | np.integer[Any],
dtype: str | type[np.integer[Any]] = np.int64,
seed: int | np.random.Generator | None = None,
start: int | np.integer[Any] = 0,
):
Expand All @@ -36,6 +37,7 @@ def __init__(

Args:
n (int): The number of elements of this space.
dtype: This should be some kind of integer type.
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the ``Dict`` space.
start (int): The smallest element of this space.
"""
Expand All @@ -47,16 +49,29 @@ def __init__(
type(start), np.integer
), f"Expects `start` to be an integer, actual type: {type(start)}"

self.n = np.int64(n)
self.start = np.int64(start)
super().__init__((), np.int64, seed)
# determine dtype
if dtype is None:
raise ValueError(
"Discrete dtype must be explicitly provided, cannot be None."
)
self.dtype = np.dtype(dtype)

# * check that dtype is an accepted dtype
if not (np.issubdtype(self.dtype, np.integer)):
raise ValueError(
f"Invalid Discrete dtype ({self.dtype}), must be an integer dtype"
)

self.n = self.dtype.type(n)
self.start = self.dtype.type(start)
super().__init__((), self.dtype, seed)

@property
def is_np_flattenable(self):
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
return True

def sample(self, mask: MaskNDArray | None = None) -> np.int64:
def sample(self, mask: MaskNDArray | None = None) -> np.integer[Any]:
"""Generates a single random sample from this space.

A sample will be chosen uniformly at random with the mask if provided
Expand Down Expand Up @@ -84,13 +99,13 @@ def sample(self, mask: MaskNDArray | None = None) -> np.int64:
np.logical_or(mask == 0, valid_action_mask)
), f"All values of a mask should be 0 or 1, actual values: {mask}"
if np.any(valid_action_mask):
return self.start + self.np_random.choice(
JDRanpariya marked this conversation as resolved.
Show resolved Hide resolved
np.where(valid_action_mask)[0]
return self.start + self.dtype.type(
self.np_random.choice(np.where(valid_action_mask)[0])
)
else:
return self.start

return self.start + self.np_random.integers(self.n)
return self.start + self.np_random.integers(self.n).astype(self.dtype)

def contains(self, x: Any) -> bool:
"""Return boolean specifying if x is a valid member of this space."""
Expand Down Expand Up @@ -137,7 +152,7 @@ def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):

super().__setstate__(state)

def to_jsonable(self, sample_n: Sequence[np.int64]) -> list[int]:
def to_jsonable(self, sample_n: Sequence[np.integer[Any]]) -> list[int]:
"""Converts a list of samples to a list of ints."""
return [int(x) for x in sample_n]

Expand Down
34 changes: 34 additions & 0 deletions tests/spaces/test_discrete.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from copy import deepcopy

import numpy as np
import pytest

from gymnasium.spaces import Discrete

Expand Down Expand Up @@ -32,3 +33,36 @@ def test_sample_mask():
assert space.sample(mask=np.array([0, 1, 0, 0], dtype=np.int8)) == 3
assert space.sample(mask=np.array([0, 0, 0, 0], dtype=np.int8)) == 2
assert space.sample(mask=np.array([0, 1, 0, 1], dtype=np.int8)) in [3, 5]


@pytest.mark.parametrize(
"dtype, sample_dtype",
[
(int, np.int64),
(np.int64, np.int64),
(np.int32, np.int32),
(np.uint8, np.uint8),
],
)
def test_dtype(dtype, sample_dtype):
space = Discrete(n=5, dtype=dtype, start=2)

sample = space.sample()
sample_mask = space.sample(mask=np.array([0, 1, 0, 0, 0], dtype=np.int8))
print(f"{sample=}, {sample_mask=}")
print(f"{type(sample)=}, {type(sample_mask)=}")
assert isinstance(sample, sample_dtype), type(sample)
assert isinstance(sample_mask, sample_dtype), type(sample_mask)


@pytest.mark.parametrize(
"dtype",
[
str,
np.float32,
np.complex64,
],
)
def test_dtype_error(dtype):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also Test dtype equals none

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey thanks for the check, I didn't bother to look if None was covered, it failed for None and had to raise error the way. Now I think of it, we aren't doing any checks for multi discrete dtype. I might add it or maybe create a issue as good first.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, feel free to create a PR for multi-discrete testing next.

with pytest.raises(ValueError, match="Invalid Discrete dtype"):
Discrete(4, dtype=dtype)
Loading