Skip to content

Commit

Permalink
Remove CEMRL
Browse files: browse the repository at this point in the history
  • Loading branch information
araffin committed Mar 23, 2020
1 parent b96a081 commit dcb54b5
Show file tree
Hide file tree
Showing 16 changed files with 38 additions and 531 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ NOTE: Python 3.6 is required!
## Implemented Algorithms

- A2C
- CEM-RL (with TD3)
- PPO
- SAC
- TD3
Expand Down
1 change: 0 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ RL Baselines zoo also offers a simple interface to train, evaluate agents and do

modules/base
modules/a2c
modules/cem_rl
modules/ppo
modules/sac
modules/td3
Expand Down
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Pre-Release 0.4.0a0 (WIP)

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed CEMRL

New Features:
^^^^^^^^^^^^^
Expand Down
96 changes: 0 additions & 96 deletions docs/modules/cem_rl.rst

This file was deleted.

11 changes: 3 additions & 8 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,17 @@
import pytest
import gym

from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
from torchy_baselines import A2C, PPO, SAC, TD3
from torchy_baselines.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback,
EveryNTimesteps, StopTrainingOnRewardThreshold)


@pytest.mark.parametrize("model_class", [A2C, CEMRL, PPO, SAC, TD3])
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
def test_callbacks(model_class):
log_folder = './logs/callbacks/'
kwargs = {}
if model_class == CEMRL:
kwargs['pop_size'] = 2
kwargs['n_grad'] = 1

# Create RL model
# Small network for fast test
model = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[32]), **kwargs)
model = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[32]))

checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)

Expand Down
3 changes: 1 addition & 2 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import gym
import pytest

from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
from torchy_baselines import A2C, PPO, SAC, TD3
from torchy_baselines.common.vec_env import DummyVecEnv

MODEL_LIST = [
CEMRL,
PPO,
A2C,
TD3,
Expand Down
8 changes: 1 addition & 7 deletions tests/test_run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest

from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
from torchy_baselines import A2C, PPO, SAC, TD3
from torchy_baselines.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))
Expand All @@ -14,12 +14,6 @@ def test_td3(action_noise):
model.learn(total_timesteps=1000, eval_freq=500)


def test_cemrl():
model = CEMRL('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[16]), pop_size=2, n_grad=1,
learning_starts=100, verbose=1, create_eval_env=True, action_noise=action_noise)
model.learn(total_timesteps=1000, eval_freq=500)


@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0'])
def test_onpolicy(model_class, env_id):
Expand Down
3 changes: 1 addition & 2 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
import torch as th
from copy import deepcopy

from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
from torchy_baselines import A2C, PPO, SAC, TD3
from torchy_baselines.common.identity_env import IdentityEnvBox
from torchy_baselines.common.vec_env import DummyVecEnv

MODEL_LIST = [
CEMRL,
PPO,
A2C,
TD3,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_vec_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from torchy_baselines.common.running_mean_std import RunningMeanStd
from torchy_baselines.common.vec_env import DummyVecEnv, VecNormalize, VecFrameStack, sync_envs_normalization, unwrap_vec_normalize
from torchy_baselines import CEMRL, SAC, TD3
from torchy_baselines import SAC, TD3

ENV_ID = 'Pendulum-v0'

Expand Down Expand Up @@ -116,7 +116,7 @@ def test_normalize_external():
assert np.all(norm_rewards < 1)


@pytest.mark.parametrize("model_class", [SAC, TD3, CEMRL])
@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_offpolicy_normalization(model_class):
env = DummyVecEnv([make_env])
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.)
Expand Down
1 change: 0 additions & 1 deletion torchy_baselines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

from torchy_baselines.a2c import A2C
from torchy_baselines.cem_rl import CEMRL
from torchy_baselines.ppo import PPO
from torchy_baselines.sac import SAC
from torchy_baselines.td3 import TD3
Expand Down
2 changes: 0 additions & 2 deletions torchy_baselines/cem_rl/__init__.py

This file was deleted.

132 changes: 0 additions & 132 deletions torchy_baselines/cem_rl/cem.py

This file was deleted.

Loading

0 comments on commit dcb54b5

Please sign in to comment.