-
-
Notifications
You must be signed in to change notification settings - Fork 839
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Discrete(dtype)
parameter
#1197
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you move the changes then good to go
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I forgot to request testing,
@pytest.mark.parametrize("dtype, sample_dtype", [
(int, np.int64),
(np.int64, np.int64),
(np.int32, np.int32),
(np.uint8, np.uint8),
])
def test_dtype(dtype, sample_dtype):
space = Discrete(n=5, dtype=dtype, start=2)
sample = space.sample()
sample_mask = space.sample(mask=np.array([0, 1, 0, 0, 0], dtype=np.int8))
print(f'{sample=}, {sample_mask=}')
print(f'{type(sample)=}, {type(sample_mask)=}')
assert isinstance(sample, sample_dtype), type(sample)
assert isinstance(sample_mask, sample_dtype), type(sample_mask)
@pytest.mark.parametrize("dtype", [
str,
np.float32,
np.complex64,
])
def test_dtype_error(dtype):
with pytest.raises(ValueError, match="Invalid Discrete dtype"):
Discrete(4, dtype=dtype)
Could you add this to tests/spaces/test_discrete.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have added the tests, thanks!
np.complex64, | ||
], | ||
) | ||
def test_dtype_error(dtype): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also Test dtype equals none
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey thanks for the check, I didn't bother to look if None was covered, it failed for None and had to raise error the way. Now I think of it, we aren't doing any checks for multi discrete dtype. I might add it or maybe create a issue as good first.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, feel free to create a PR for multi-discrete testing next.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just realised that we need to update the contains
to work with other dtype.
@JDRanpariya Could you update contains
and add to test_dtype
, checking that the samples are contained within the space.
It might be worth checking if failure cases as well including int
for np.int64
dtype and empty shape arrays
Hmm, there is no test for Would look into it, shortly. I would like to spend sometime looking at whole code base for a while. |
Description
Add support for dtype in Discrete Space.
Fixes # (issue)
#1118
Type of change
Please delete options that are not relevant.
Checklist:
pre-commit
checks withpre-commit run --all-files
(seeCONTRIBUTING.md
instructions to set it up)