Skip to content

Commit

Permalink
Add descriptions to some asserts in the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mwydmuch committed Jan 11, 2024
1 parent a59629b commit f5b9ae7
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 39 deletions.
6 changes: 4 additions & 2 deletions examples/python/record_episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,15 @@
game.replay_episode(f"episode{i}_rec.lmp")

while not game.is_episode_finished():
# Get a state
s = game.get_state()

# Use advance_action instead of make_action.
# Use advance_action instead of make_action to proceed
game.advance_action()

r = game.get_last_reward()
# Retrieve the last actions and the reward
a = game.get_last_action()
r = game.get_last_reward()

print(f"State #{s.number}")
print("Action:", a)
Expand Down
139 changes: 139 additions & 0 deletions tests/manual_test_seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
#!/usr/bin/env python3ch

# Tests ViZDoom seed option.
# This test can be run as Python script or via PyTest

import itertools
import os
import random

import cv2
import numpy as np

import vizdoom as vzd


def test_seed(repeats=10, tics=8, audio_buffer=False, seed=1993):
scenarios_to_skip = [
# "deadly_corridor.cfg",
# "defend_the_center.cfg",
# "deathmatch.cfg",
# "health_gathering.cfg",
# "health_gathering_supreme.cfg",
# "deathmatch.cfg",
# Multiplayer scenarios
"cig.cfg",
"multi_duel.cfg",
"multi.cfg",
"oblige.cfg",
]
configs = [
file
for file in os.listdir(vzd.scenarios_path)
if file.endswith(".cfg") and file not in scenarios_to_skip
]
print(configs)
game = vzd.DoomGame()

for config in configs:
print(config)
initial_states = []
states_after_action = []

game = vzd.DoomGame()
game.load_config(config)
game.set_window_visible(False)

# Creates all possible actions depending on how many buttons there are.
actions_num = game.get_available_buttons_size()
actions = []
for perm in itertools.product([False, True], repeat=actions_num):
actions.append(list(perm))

# Enable all buffers
buffers = ["screen_buffer", "depth_buffer", "labels_buffer", "automap_buffer"]
game.set_depth_buffer_enabled(True)
game.set_labels_buffer_enabled(True)
game.set_automap_buffer_enabled(True)
game.set_objects_info_enabled(True)
game.set_sectors_info_enabled(True)
game.set_audio_buffer_enabled(audio_buffer)
if audio_buffer:
buffers.append("audio_buffer")

game.set_screen_format(vzd.ScreenFormat.BGR24)

game.init()

for i in range(repeats):
game.set_seed(1993)
random.seed(seed)
# game.init()
game.new_episode()

initial_states.append(game.get_state())
if i % 2 == 0:
game.make_action(random.choice(actions), tics=tics)
else:
action = random.choice(actions)
for _ in range(tics):
game.make_action(action, tics=1)

game.make_action(random.choice(actions), tics=tics)
states_after_action.append(game.get_state())

# game.close()

for s1, s2 in zip(initial_states[:-1], initial_states[1:]):
assert s1.tic == s2.tic
assert np.array_equal(s1.game_variables, s2.game_variables)

if not np.array_equal(s1.screen_buffer, s2.screen_buffer):
print("Initial states are not equal")
print(f"s1: {s1.tic}, {s1.game_variables}")
print(f"s2: {s2.tic}, {s2.game_variables}")
print(np.all(s1.screen_buffer == s2.screen_buffer))
print(np.array_equal(s1.screen_buffer, s2.screen_buffer))
cv2.imshow("s1", s1.screen_buffer)
cv2.imshow("s2", s2.screen_buffer)
cv2.imshow("s1 - s2", s1.screen_buffer - s2.screen_buffer)
cv2.waitKey(int(10000))

for b in buffers:
if not np.array_equal(getattr(s1, b), getattr(s2, b)):
print("Initial states are not equal")
cv2.imshow("s1", getattr(s1, b))
cv2.imshow("s2", getattr(s2, b))
cv2.imshow("s1 - s2", getattr(s1, b) - getattr(s2, b))
cv2.waitKey(int(10000))

# assert np.array_equal(getattr(s1, b), getattr(s2, b))

for s1, s2 in zip(states_after_action[:-1], states_after_action[1:]):
assert s1.tic == s2.tic
assert np.array_equal(s1.game_variables, s2.game_variables)

if not np.array_equal(s1.screen_buffer, s2.screen_buffer):
print("States after action are not equal")
print(f"s1: {s1.tic}, {s1.game_variables}")
print(f"s2: {s2.tic}, {s2.game_variables}")
print(np.all(s1.screen_buffer == s2.screen_buffer))
print(np.array_equal(s1.screen_buffer, s2.screen_buffer))
cv2.imshow("s1", s1.screen_buffer)
cv2.imshow("s2", s2.screen_buffer)
cv2.imshow("s1 - s2", s1.screen_buffer - s2.screen_buffer)
cv2.waitKey(int(10000))

for b in buffers:
if not np.array_equal(getattr(s1, b), getattr(s2, b)):
print("States after action are not equal")
cv2.imshow("s1", getattr(s1, b))
cv2.imshow("s2", getattr(s2, b))
cv2.imshow("s1 - s2", getattr(s1, b) - getattr(s2, b))
cv2.waitKey(int(10000))

# assert np.array_equal(getattr(s1, b), getattr(s2, b))


if __name__ == "__main__":
test_seed()
14 changes: 7 additions & 7 deletions tests/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,30 @@ def _test_enums(enum_name, func_name):
]
all_values_names = [v.name for v in all_values]

# set test
# set_X function test
set_func(all_values)
get_buttons_names = [v.name for v in get_func()]
assert all_values_names == get_buttons_names
get_values_names = [v.name for v in get_func()]
assert all_values_names == get_values_names

# add test
# add_X function test
clear_func()
for i, v in enumerate(all_values):
add_func(v)
get_values_names = [v.name for v in get_func()]
assert all_values_names[: i + 1] == get_values_names

# again set test
# Check if set function overwrites previous values
set_func(all_values)
get_values_names = [v.name for v in get_func()]
assert all_values_names == get_values_names

# multiple adds
# Multiple add_X functions test
for i, v in enumerate(all_values):
add_func(v)
get_values_names = [v.name for v in get_func()]
assert all_values_names == get_values_names

# multiple in set
# Test duplicated values in set_X function
set_func(all_values + all_values)
get_values_names = [v.name for v in get_func()]
assert all_values_names == get_values_names
Expand Down
5 changes: 3 additions & 2 deletions tests/test_game_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


def test_game_args():
print("Testing setting custom game arguments...")

game = vzd.DoomGame()
game.set_window_visible(False)
Expand All @@ -16,12 +17,12 @@ def test_game_args():
args_all = args1 + " " + args2

game.set_game_args(args_all)
assert game.get_game_args() == args_all
assert game.get_game_args() == args_all, "Game args not set correctly."

game.clear_game_args()
game.add_game_args(args1)
game.add_game_args(args2)
assert game.get_game_args() == args_all
assert game.get_game_args() == args_all, "Game args not set correctly."


if __name__ == "__main__":
Expand Down
64 changes: 36 additions & 28 deletions tests/test_get_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@ def _test_get_state(
num_iterations=10,
num_states=20,
mem_eta_mb=0,
depthBuffer=False,
labelsBuffer=False,
automapBuffer=False,
objectsInfo=False,
sectorsInfo=False,
audioBuffer=False,
depth_buffer=False,
labels_buffer=False,
automap_buffer=False,
objects_info=False,
sectors_info=False,
audio_buffer=False,
seed=1993,
):
print("Testing get_state() ...")

random.seed(1993)
random.seed(seed)

buttons = [
vzd.Button.MOVE_FORWARD,
Expand All @@ -46,21 +47,21 @@ def _test_get_state(
game.set_episode_timeout(num_states)
game.set_available_buttons(buttons)

game.set_depth_buffer_enabled(depthBuffer)
game.set_labels_buffer_enabled(labelsBuffer)
game.set_automap_buffer_enabled(automapBuffer)
game.set_objects_info_enabled(objectsInfo)
game.set_sectors_info_enabled(sectorsInfo)
game.set_audio_buffer_enabled(audioBuffer)
game.set_depth_buffer_enabled(depth_buffer)
game.set_labels_buffer_enabled(labels_buffer)
game.set_automap_buffer_enabled(automap_buffer)
game.set_objects_info_enabled(objects_info)
game.set_sectors_info_enabled(sectors_info)
game.set_audio_buffer_enabled(audio_buffer)

buffers = ["screen_buffer"]
if depthBuffer:
if depth_buffer:
buffers.append("depth_buffer")
if labelsBuffer:
if labels_buffer:
buffers.append("labels_buffer")
if automapBuffer:
if automap_buffer:
buffers.append("automap_buffer")
if audioBuffer:
if audio_buffer:
buffers.append("audio_buffer")
# This fixes "BiquadFilter_setParams: Assertion `gain > 0.00001f' failed" issue
# or "no audio in buffer" issue caused by a bug in OpenAL version 1.19.
Expand Down Expand Up @@ -97,13 +98,16 @@ def _test_get_state(
max_vals = {b: -np.inf for b in buffers}
for s, bs_copy in zip(states, buffers_copies):
for b in buffers:
assert np.array_equal(getattr(s, b), bs_copy[b])
assert np.array_equal(
getattr(s, b), bs_copy[b]
), f"Buffer {b} is not equal with its copy"
min_vals[b] = min(min_vals[b], np.min(bs_copy[b]))
max_vals[b] = max(max_vals[b], np.max(bs_copy[b]))

for b in buffers:
print(f"Buffer {b} min: {min_vals[b]}, max: {max_vals[b]}")
assert min_vals[b] != max_vals[b]
assert (
min_vals[b] != max_vals[b]
), f"Buffer {b} min: {min_vals[b]}, max: {max_vals[b]} are equal, buffer is empty"

# Save and load states via pickle - confirms that states and all sub-objects (labels, lines, objects) are picklable.
with open("tmp_states.pkl", "wb") as f:
Expand All @@ -114,7 +118,9 @@ def _test_get_state(

# Compare loaded states with their copies - to confirm that pickling doesn't mutate states.
for s, s_copy in zip(states, pickled_states):
assert pickle.dumps(s) == pickle.dumps(s_copy)
assert pickle.dumps(s) == pickle.dumps(
s_copy
), "Pickled state is not equal with its original object after save and load"

del pickled_states
os.remove("tmp_states.pkl")
Expand All @@ -133,7 +139,9 @@ def _test_get_state(
prev_mem = mem
prev_len = len(states)
elif prev_len == len(states):
assert abs(prev_mem - mem) < mem_eta_mb
assert (
abs(prev_mem - mem) < mem_eta_mb
), f"Memory leak detected: with {len(states)} states saved, after episode {i + 1} / {num_iterations}: {mem} MB used, expected ~{prev_mem} +/- {mem_eta_mb} MB"


def test_get_state(num_iterations=10, num_states=20):
Expand All @@ -142,12 +150,12 @@ def test_get_state(num_iterations=10, num_states=20):
num_iterations=num_iterations,
num_states=num_states,
mem_eta_mb=0,
depthBuffer=True,
labelsBuffer=True,
automapBuffer=True,
objectsInfo=True,
sectorsInfo=True,
audioBuffer=False, # Turned off by default, because it fails on some systems without audio backend and OpenAL installed
depth_buffer=True,
labels_buffer=True,
automap_buffer=True,
objects_info=True,
sectors_info=True,
audio_buffer=False, # Turned off by default, because it fails on some systems without audio backend and OpenAL installed
)


Expand Down
1 change: 1 addition & 0 deletions tests/test_labels_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def check_label(labels_buffer, label):


def test_labels_buffer():
print("Testing labels buffer ...")
game = vzd.DoomGame()
game.load_config(os.path.join(vzd.scenarios_path, "deathmatch.cfg"))

Expand Down

0 comments on commit f5b9ae7

Please sign in to comment.