Skip to content

Commit

Permalink
Merge pull request #17 from sisl/fix-tests
Browse files Browse the repository at this point in the history
Fix package tests that are failing
  • Loading branch information
cpondoc authored Jan 21, 2025
2 parents 30a5230 + b3930d3 commit 7c61e90
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions tests/package_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ def test_constructor():
env = gymnasium.make("pyrorl/PyroRL-v0", **kwargs)

# Make basic checks for the constructor
assert env.num_rows == num_rows
assert env.num_cols == num_cols
np.testing.assert_array_equal(env.populated_areas, populated_areas)
np.testing.assert_array_equal(env.paths, paths)
assert env.unwrapped.num_rows == num_rows
assert env.unwrapped.num_cols == num_cols
np.testing.assert_array_equal(env.unwrapped.populated_areas, populated_areas)
np.testing.assert_array_equal(env.unwrapped.paths, paths)

# Special check for paths to populated areas
for key in paths_to_pops:
np.testing.assert_array_equal(
np.array(env.paths_to_pops[key]),
np.array(env.unwrapped.paths_to_pops[key]),
np.array(paths_to_pops[key]),
)

Expand Down Expand Up @@ -112,15 +112,15 @@ def test_reset():

# Check that reset makes it all the same
env.reset()
assert env.num_rows == num_rows
assert env.num_cols == num_cols
np.testing.assert_array_equal(env.populated_areas, populated_areas)
np.testing.assert_array_equal(env.paths, paths)
assert env.unwrapped.num_rows == num_rows
assert env.unwrapped.num_cols == num_cols
np.testing.assert_array_equal(env.unwrapped.populated_areas, populated_areas)
np.testing.assert_array_equal(env.unwrapped.paths, paths)

# Special check for paths to populated areas
for key in paths_to_pops:
np.testing.assert_array_equal(
np.array(env.paths_to_pops[key]),
np.array(env.unwrapped.paths_to_pops[key]),
np.array(paths_to_pops[key]),
)

Expand Down Expand Up @@ -242,6 +242,6 @@ def test_generate_gif(mocker):
env.render()

# Generate the gif, check that it exists, and then remove it
env.generate_gif()
env.unwrapped.generate_gif()
assert os.path.exists("training.gif")
os.remove("training.gif")

0 comments on commit 7c61e90

Please sign in to comment.