Fixed a couple of errors and added new tests (#380)
* minor error fixes

* Fixed Evaluate error and Action/Obs shape errors

* Code Climate fix

* Modularised tests and added a few new ones
sampreet-arthi authored Oct 12, 2020
1 parent a2c8c7e commit 25eb018
Showing 20 changed files with 377 additions and 319 deletions.
2 changes: 1 addition & 1 deletion genrl/core/rollout_storage.py
@@ -211,7 +211,7 @@ def add(
log_prob = log_prob.reshape(-1, 1)

self.observations[self.pos] = obs.detach().clone()
self.actions[self.pos] = action.squeeze().detach().clone()
self.actions[self.pos] = action.detach().clone()
self.rewards[self.pos] = reward.detach().clone()
self.dones[self.pos] = done.detach().clone()
self.values[self.pos] = value.detach().clone().flatten()
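
Note on the rollout_storage change above: the per-step action tensor arrives with an explicit action dimension, and calling squeeze() collapses it before the write into self.actions, so it can no longer broadcast into the reserved buffer slot. A minimal sketch of the failure mode with assumed shapes (illustrative only, not taken from the repository):

import torch

# Assumed shapes for illustration: one rollout step, two envs, one action dim.
n_envs, action_dim = 2, 1
actions_buffer = torch.zeros(1, n_envs, action_dim)
action = torch.tensor([[0.7], [0.3]])                    # (n_envs, action_dim) from the policy

actions_buffer[0] = action.detach().clone()              # fine: (2, 1) written into a (2, 1) slot
# actions_buffer[0] = action.squeeze().detach().clone()  # (2,) cannot broadcast into (2, 1) -> RuntimeError
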
18 changes: 13 additions & 5 deletions genrl/environments/vec_env/vector_envs.py
@@ -55,12 +55,10 @@ def __init__(self, envs: List, n_envs: int = 2):
self.envs = envs
self.env = envs[0]
self._n_envs = n_envs

self.episode_reward = torch.zeros(self.n_envs)
self.observation_space = self.env.observation_space
self.action_space = self.env.action_space

self.episode_reward = torch.zeros(self.n_envs)

def __getattr__(self, name: str) -> Any:
env = super(VecEnv, self).__getattribute__("env")
return getattr(env, name)
@@ -119,12 +117,22 @@ def action_spaces(self):

@property
def obs_shape(self):
obs_shape = self.observation_space.shape
if isinstance(self.observation_space, gym.spaces.Discrete):
obs_shape = (1,)
elif isinstance(self.observation_space, gym.spaces.Box):
obs_shape = self.observation_space.shape
else:
raise NotImplementedError
return obs_shape

@property
def action_shape(self):
action_shape = self.action_space.shape
if isinstance(self.action_space, gym.spaces.Box):
action_shape = self.action_space.shape
elif isinstance(self.action_space, gym.spaces.Discrete):
action_shape = (1,)
else:
raise NotImplementedError
return action_shape


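
With the changes above, obs_shape and action_shape branch on the space type: Box spaces report their native shape, Discrete spaces map to (1,) so a single index is stored per environment, and any other space type raises NotImplementedError. A quick illustrative check of the two cases (not part of the diff; uses CartPole as an example):

import gym

env = gym.make("CartPole-v0")
print(type(env.observation_space).__name__, env.observation_space.shape)  # Box (4,)   -> obs_shape = (4,)
print(type(env.action_space).__name__, env.action_space.n)                # Discrete 2 -> action_shape = (1,)
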
5 changes: 3 additions & 2 deletions genrl/trainers/base.py
@@ -53,7 +53,7 @@ def __init__(
load_weights: str = None,
load_hyperparams: str = None,
render: bool = False,
evaluate_episodes: int = 50,
evaluate_episodes: int = 25,
seed: Optional[int] = None,
):
self.agent = agent
@@ -111,8 +111,9 @@ def evaluate(self, render: bool = False) -> None:
for i, di in enumerate(done):
if di:
episode += 1
episode_rewards.append(episode_reward[i])
episode_rewards.append(episode_reward[i].clone().detach())
episode_reward[i] = 0
self.env.reset_single_env(i)
if episode == self.evaluate_episodes:
print(
"Evaluated for {} episodes, Mean Reward: {:.2f}, Std Deviation for the Reward: {:.2f}".format(
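
The evaluate() change above appends a detached clone of episode_reward[i] rather than the element itself. Indexing a torch tensor returns a view into the same storage, so zeroing the slot afterwards would silently wipe the value that was just appended. The finished sub-environment is also reset individually via reset_single_env. A minimal sketch of the aliasing issue (illustrative only):

import torch

episode_reward = torch.tensor([3.0, 5.0])       # running reward per env
aliased = episode_reward[0]                      # view into the same storage
cloned = episode_reward[0].clone().detach()      # independent copy
episode_reward[0] = 0                            # reset the finished env's slot
print(aliased.item(), cloned.item())             # 0.0 3.0 -- only the clone keeps the logged value
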
12 changes: 6 additions & 6 deletions tests/test_agents/test_deep/__init__.py
@@ -1,8 +1,8 @@
from tests.test_agents.test_deep.test_a2c import test_a2c, test_a2c_cnn
from tests.test_agents.test_deep.test_ddpg import test_ddpg
from tests.test_agents.test_deep.test_a2c import TestA2C
from tests.test_agents.test_deep.test_ddpg import TestDDPG
from tests.test_agents.test_deep.test_dqn import TestDQN
from tests.test_agents.test_deep.test_dqn_cnn import TestDQNCNN
from tests.test_agents.test_deep.test_ppo1 import test_ppo1, test_ppo1_cnn
from tests.test_agents.test_deep.test_sac import test_sac
from tests.test_agents.test_deep.test_td3 import test_td3
from tests.test_agents.test_deep.test_vpg import test_vpg, test_vpg_cnn
from tests.test_agents.test_deep.test_ppo1 import TestPPO
from tests.test_agents.test_deep.test_sac import TestSAC
from tests.test_agents.test_deep.test_td3 import TestTD3
from tests.test_agents.test_deep.test_vpg import TestVPG
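
The package __init__ now re-exports one test class per algorithm. Under pytest's default discovery rules, Test*-prefixed classes without an __init__ and their test_*-prefixed methods are collected automatically, so no extra configuration is needed. A hypothetical sketch of the pattern used throughout these files:

class TestDummyAgent:
    # pytest collects this class and each test_* method without any registration step
    def test_discrete_env(self):
        assert isinstance((1,), tuple)

    def test_continuous_env(self):
        assert isinstance((3,), tuple)
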
54 changes: 35 additions & 19 deletions tests/test_agents/test_deep/test_a2c.py
@@ -5,25 +5,41 @@
from genrl.trainers import OnPolicyTrainer


def test_a2c():
env = VectorEnv("CartPole-v0", 1)
algo = A2C("mlp", env, rollout_size=128)
trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
trainer.train()
shutil.rmtree("./logs")
class TestA2C:
def test_a2c_discrete(self):
env = VectorEnv("CartPole-v0", 1)
algo = A2C("mlp", env, rollout_size=128)
trainer = OnPolicyTrainer(
algo, env, log_mode=["csv"], logdir="./logs", epochs=1
)
trainer.train()
trainer.evaluate()
shutil.rmtree("./logs")

def test_a2c_continuous(self):
env = VectorEnv("Pendulum-v0", 1)
algo = A2C("mlp", env, rollout_size=128)
trainer = OnPolicyTrainer(
algo, env, log_mode=["csv"], logdir="./logs", epochs=1
)
trainer.train()
trainer.evaluate()
shutil.rmtree("./logs")

def test_a2c_cnn():
env = VectorEnv("Pong-v0", 1, env_type="atari")
algo = A2C("cnn", env, rollout_size=128)
trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
trainer.train()
shutil.rmtree("./logs")
def test_a2c_cnn(self):
env = VectorEnv("Pong-v0", 1, env_type="atari")
algo = A2C("cnn", env, rollout_size=128)
trainer = OnPolicyTrainer(
algo, env, log_mode=["csv"], logdir="./logs", epochs=1
)
trainer.train()
shutil.rmtree("./logs")


def test_a2c_shared():
env = VectorEnv("CartPole-v0", 1)
algo = A2C("mlp", env, shared_layers=(32, 32), rollout_size=128)
trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
trainer.train()
shutil.rmtree("./logs")
def test_a2c_shared_discrete(self):
env = VectorEnv("CartPole-v0", 1)
algo = A2C("mlp", env, shared_layers=(32, 32), rollout_size=128)
trainer = OnPolicyTrainer(
algo, env, log_mode=["csv"], logdir="./logs", epochs=1
)
trainer.train()
shutil.rmtree("./logs")
42 changes: 23 additions & 19 deletions tests/test_agents/test_deep/test_custom.py
@@ -46,27 +46,31 @@ def get_params(self):
return actor_params, critic_params


def test_custom_vpg():
env = VectorEnv("CartPole-v0", 1)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
policy = custom_policy(state_dim, action_dim)
class TestCustomAgents:
def test_custom_vpg(self):
env = VectorEnv("CartPole-v0", 1)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
policy = custom_policy(state_dim, action_dim)

algo = VPG(policy, env)

trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
trainer.train()
shutil.rmtree("./logs")
algo = VPG(policy, env)

trainer = OnPolicyTrainer(
algo, env, log_mode=["csv"], logdir="./logs", epochs=1
)
trainer.train()
shutil.rmtree("./logs")

def test_custom_ppo1():
env = VectorEnv("CartPole-v0", 1)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
actorcritic = custom_actorcritic(state_dim, action_dim)
def test_custom_ppo1(self):
env = VectorEnv("CartPole-v0", 1)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
actorcritic = custom_actorcritic(state_dim, action_dim)

algo = PPO1(actorcritic, env)
algo = PPO1(actorcritic, env)

trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
trainer.train()
shutil.rmtree("./logs")
trainer = OnPolicyTrainer(
algo, env, log_mode=["csv"], logdir="./logs", epochs=1
)
trainer.train()
shutil.rmtree("./logs")
92 changes: 46 additions & 46 deletions tests/test_agents/test_deep/test_ddpg.py
@@ -6,52 +6,52 @@
from genrl.trainers import OffPolicyTrainer


def test_ddpg():
env = VectorEnv("Pendulum-v0", 2)
algo = DDPG(
"mlp",
env,
batch_size=5,
noise=NormalActionNoise,
policy_layers=[1, 1],
value_layers=[1, 1],
)
class TestDDPG:
def test_ddpg(self):
env = VectorEnv("Pendulum-v0", 2)
algo = DDPG(
"mlp",
env,
batch_size=5,
noise=NormalActionNoise,
policy_layers=[1, 1],
value_layers=[1, 1],
)

trainer = OffPolicyTrainer(
algo,
env,
log_mode=["csv"],
logdir="./logs",
epochs=4,
max_ep_len=200,
warmup_steps=10,
start_update=10,
)
trainer.train()
shutil.rmtree("./logs")
trainer = OffPolicyTrainer(
algo,
env,
log_mode=["csv"],
logdir="./logs",
epochs=4,
max_ep_len=200,
warmup_steps=10,
start_update=10,
)
trainer.train()
shutil.rmtree("./logs")

def test_ddpg_shared(self):
env = VectorEnv("Pendulum-v0", 2)
algo = DDPG(
"mlp",
env,
batch_size=5,
noise=NormalActionNoise,
shared_layers=[1, 1],
policy_layers=[1, 1],
value_layers=[1, 1],
)

def test_ddpg_shared():
env = VectorEnv("Pendulum-v0", 2)
algo = DDPG(
"mlp",
env,
batch_size=5,
noise=NormalActionNoise,
shared_layers=[1, 1],
policy_layers=[1, 1],
value_layers=[1, 1],
)

trainer = OffPolicyTrainer(
algo,
env,
log_mode=["csv"],
logdir="./logs",
epochs=4,
max_ep_len=200,
warmup_steps=10,
start_update=10,
)
trainer.train()
shutil.rmtree("./logs")
trainer = OffPolicyTrainer(
algo,
env,
log_mode=["csv"],
logdir="./logs",
epochs=4,
max_ep_len=200,
warmup_steps=10,
start_update=10,
)
trainer.train()
shutil.rmtree("./logs")
3 changes: 3 additions & 0 deletions tests/test_agents/test_deep/test_dqn.py
@@ -35,6 +35,7 @@ def test_vanilla_dqn(self):
start_update=10,
)
trainer.train()
trainer.evaluate()
shutil.rmtree("./logs")

def test_double_dqn(self):
@@ -93,6 +94,7 @@ def test_prioritized_dqn(self):
log_interval=1,
)
trainer.train()
trainer.evaluate()
shutil.rmtree("./logs")

def test_noisy_dqn(self):
@@ -133,4 +135,5 @@ def test_categorical_dqn(self):
start_update=10,
)
trainer.train()
trainer.evaluate()
shutil.rmtree("./logs")
2 changes: 2 additions & 0 deletions tests/test_agents/test_deep/test_dqn_cnn.py
@@ -46,6 +46,7 @@ def test_double_dqn(self):
start_update=10,
)
trainer.train()
trainer.evaluate()
shutil.rmtree("./logs")

def test_dueling_dqn(self):
@@ -66,6 +67,7 @@ def test_dueling_dqn(self):
start_update=10,
)
trainer.train()
trainer.evaluate()
shutil.rmtree("./logs")

def test_prioritized_dqn(self):
52 changes: 33 additions & 19 deletions tests/test_agents/test_deep/test_ppo1.py
@@ -5,25 +5,39 @@
from genrl.trainers import OnPolicyTrainer


def test_ppo1():
env = VectorEnv("CartPole-v0")
algo = PPO1("mlp", env, rollout_size=128)
trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
trainer.train()
shutil.rmtree("./logs")
class TestPPO:
def test_ppo1(self):
env = VectorEnv("CartPole-v0")
algo = PPO1("mlp", env, rollout_size=128)
trainer = OnPolicyTrainer(
algo, env, log_mode=["csv"], logdir="./logs", epochs=1
)
trainer.train()
shutil.rmtree("./logs")

def test_ppo1_continuous(self):
env = VectorEnv("Pendulum-v0")
algo = PPO1("mlp", env, rollout_size=128)
trainer = OnPolicyTrainer(
algo, env, log_mode=["csv"], logdir="./logs", epochs=1
)
trainer.train()
shutil.rmtree("./logs")

def test_ppo1_cnn():
env = VectorEnv("Pong-v0", env_type="atari")
algo = PPO1("cnn", env, rollout_size=128)
trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
trainer.train()
shutil.rmtree("./logs")
def test_ppo1_cnn(self):
env = VectorEnv("Pong-v0", env_type="atari")
algo = PPO1("cnn", env, rollout_size=128)
trainer = OnPolicyTrainer(
algo, env, log_mode=["csv"], logdir="./logs", epochs=1
)
trainer.train()
shutil.rmtree("./logs")


def test_ppo1_shared():
env = VectorEnv("CartPole-v0")
algo = PPO1("mlp", env, shared_layers=[32, 32], rollout_size=128)
trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
trainer.train()
shutil.rmtree("./logs")
def test_ppo1_shared(self):
env = VectorEnv("CartPole-v0")
algo = PPO1("mlp", env, shared_layers=[32, 32], rollout_size=128)
trainer = OnPolicyTrainer(
algo, env, log_mode=["csv"], logdir="./logs", epochs=1
)
trainer.train()
shutil.rmtree("./logs")