Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Discrete(dtype) parameter #1197

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

JDRanpariya
Copy link
Contributor

Description

Add support for dtype in Discrete Space.

Fixes # (issue)
#1118

Type of change

Please delete options that are not relevant.

  • Documentation only change (no code changed)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • I have run the pre-commit checks with pre-commit run --all-files (see CONTRIBUTING.md instructions to set it up)
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copy link
Member

@pseudo-rnd-thoughts pseudo-rnd-thoughts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you move the changes then good to go

gymnasium/spaces/discrete.py Show resolved Hide resolved
gymnasium/spaces/discrete.py Outdated Show resolved Hide resolved
@pseudo-rnd-thoughts pseudo-rnd-thoughts changed the title #1118 Add Discrete(dtype) parameter Oct 6, 2024
Copy link
Member

@pseudo-rnd-thoughts pseudo-rnd-thoughts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forgot to request testing,

@pytest.mark.parametrize("dtype, sample_dtype", [
    (int, np.int64),
    (np.int64, np.int64),
    (np.int32, np.int32),
    (np.uint8, np.uint8),
])
def test_dtype(dtype, sample_dtype):
    space = Discrete(n=5, dtype=dtype, start=2)

    sample = space.sample()
    sample_mask = space.sample(mask=np.array([0, 1, 0, 0, 0], dtype=np.int8))
    print(f'{sample=}, {sample_mask=}')
    print(f'{type(sample)=}, {type(sample_mask)=}')
    assert isinstance(sample, sample_dtype), type(sample)
    assert isinstance(sample_mask, sample_dtype), type(sample_mask)


@pytest.mark.parametrize("dtype", [
    str,
    np.float32,
    np.complex64,
])
def test_dtype_error(dtype):
    with pytest.raises(ValueError, match="Invalid Discrete dtype"):
        Discrete(4, dtype=dtype)

Could you add this to tests/spaces/test_discrete.py

Copy link
Contributor Author

@JDRanpariya JDRanpariya left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have added the tests, thanks!

np.complex64,
],
)
def test_dtype_error(dtype):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also Test dtype equals none

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey thanks for the check, I didn't bother to look if None was covered, it failed for None and had to raise error the way. Now I think of it, we aren't doing any checks for multi discrete dtype. I might add it or maybe create a issue as good first.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, feel free to create a PR for multi-discrete testing next.

Copy link
Member

@pseudo-rnd-thoughts pseudo-rnd-thoughts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realised that we need to update the contains to work with other dtype.
@JDRanpariya Could you update contains and add to test_dtype, checking that the samples are contained within the space.
It might be worth checking if failure cases as well including int for np.int64 dtype and empty shape arrays

@JDRanpariya
Copy link
Contributor Author

JDRanpariya commented Oct 7, 2024

Hmm, there is no test for contains, to_jsonable and from_jsonable, would have to know the expected behavior and write tests for all three of them.

Would look into it, shortly. I would like to spend sometime looking at whole code base for a while.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants