Skip to content
This repository was archived by the owner on Jan 27, 2023. It is now read-only.

Commit 0e0d51f

Browse files
authored
Merge pull request #32 from kngwyu/cli-override
Add override option to CLI
2 parents 8a3a3e0 + 30fa4bf commit 0e0d51f

19 files changed

+122
-78
lines changed

README.md

+24
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,30 @@ E.g., if you want to use two hosts(localhost and anotherhost) and run `ppo_atari
6767
horovodrun -np 2 -H localhost:1,anotherhost:1 pipenv run python examples/ppo_atari.py train
6868
```
6969

70+
## Override configuration from CLI
71+
Currently, Rainy provides an easy-to-use CLI via [click](https://palletsprojects.com/p/click/).
72+
You can view its usage by, say,
73+
```bash
74+
pipenv run python examples/a2c_cart_pole.py --help
75+
```
76+
77+
This CLI has a simple data-driven interface.
78+
I.e., once you fill a config object, all commands (train, eval, retrain, etc.) work.
79+
So you can start experiments easily without copying and pasting, say, argument parser codes.
80+
81+
However, it has a limitation that you cannot add new options.
82+
83+
So Rainy-CLI provides an option named `override`, which executes the given string as Python code
84+
with the config object set as `config`.
85+
86+
Example usage:
87+
```bash
88+
pipenv run python examples/a2c_cart_pole.py --override='config.grad_clip=0.5; config.nsteps=10' train
89+
```
90+
91+
If this feature still doesn't satisfy your requirements, you can
92+
[override subcommands by `ctx.invoke`](https://click.palletsprojects.com/en/7.x/advanced/#invoking-other-commands).
93+
7094
## Implementation Status
7195

7296
|**Algorithm** |**Multi Worker(Sync)**|**Recurrent** |**Discrete Action** |**Continuous Action**|**MPI** |

examples/a2c_atari.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from torch.optim import RMSprop
1010

1111

12-
def config() -> Config:
12+
def config(game: str = 'Breakout') -> Config:
1313
c = Config()
14-
c.set_env(lambda: Atari('Breakout', frame_stack=False))
14+
c.set_env(lambda: Atari(game, frame_stack=False))
1515
c.set_optimizer(
1616
lambda params: RMSprop(params, lr=7e-4, alpha=0.99, eps=1e-5)
1717
)
@@ -24,7 +24,7 @@ def config() -> Config:
2424
c.value_loss_weight = 1.0
2525
c.use_gae = False
2626
c.max_steps = int(2e7)
27-
c.eval_env = Atari('Breakout')
27+
c.eval_env = Atari(game)
2828
c.use_reward_monitor = True
2929
c.eval_deterministic = False
3030
c.episode_log_freq = 100
@@ -34,4 +34,4 @@ def config() -> Config:
3434

3535

3636
if __name__ == '__main__':
37-
cli.run_cli(config(), A2cAgent, script_path=os.path.realpath(__file__))
37+
cli.run_cli(config, A2cAgent, script_path=os.path.realpath(__file__))

examples/a2c_cart_pole.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@ def config() -> rainy.Config:
2323

2424

2525
if __name__ == '__main__':
26-
run_cli(config(), rainy.agents.A2cAgent, script_path=os.path.realpath(__file__))
26+
run_cli(config, rainy.agents.A2cAgent, script_path=os.path.realpath(__file__))

examples/a2c_hopper.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from torch.optim import Adam
88

99

10-
def config() -> Config:
10+
def config(envname: str = 'Hopper') -> Config:
1111
c = Config()
12-
c.set_env(lambda: PyBullet('Hopper'))
12+
c.set_env(lambda: PyBullet(envname))
1313
c.set_net_fn('actor-critic', net.actor_critic.fc_shared(policy=SeparateStdGaussianDist))
1414
c.set_parallel_env(pybullet_parallel())
1515
c.max_steps = int(1e6)
@@ -27,4 +27,4 @@ def config() -> Config:
2727

2828

2929
if __name__ == '__main__':
30-
cli.run_cli(config(), A2cAgent, script_path=os.path.realpath(__file__))
30+
cli.run_cli(config, A2cAgent, script_path=os.path.realpath(__file__))

examples/acktr_atari.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
}
1313

1414

15-
def config() -> Config:
15+
def config(game: str = 'Breakout') -> Config:
1616
c = Config()
17-
c.set_env(lambda: Atari('Breakout', frame_stack=False))
17+
c.set_env(lambda: Atari(game, frame_stack=False))
1818
c.set_optimizer(kfac.default_sgd(eta_max=0.2))
1919
c.set_preconditioner(lambda net: kfac.KfacPreConditioner(net, **KFAC_KWARGS))
2020
c.set_net_fn('actor-critic', net.actor_critic.ac_conv())
@@ -25,7 +25,7 @@ def config() -> Config:
2525
c.use_gae = True
2626
c.lr_min = 0.0
2727
c.max_steps = int(2e7)
28-
c.eval_env = Atari('Breakout')
28+
c.eval_env = Atari(game)
2929
c.eval_freq = None
3030
c.episode_log_freq = 100
3131
c.use_reward_monitor = True
@@ -34,4 +34,4 @@ def config() -> Config:
3434

3535

3636
if __name__ == '__main__':
37-
cli.run_cli(config(), AcktrAgent, script_path=os.path.realpath(__file__))
37+
cli.run_cli(config, AcktrAgent, script_path=os.path.realpath(__file__))

examples/acktr_cart_pole.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ def config() -> Config:
3131

3232

3333
if __name__ == '__main__':
34-
cli.run_cli(config(), AcktrAgent, script_path=os.path.realpath(__file__))
34+
cli.run_cli(config, AcktrAgent, script_path=os.path.realpath(__file__))

examples/acktr_hopper.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
}
1515

1616

17-
def config() -> Config:
17+
def config(envname: str = 'Hopper') -> Config:
1818
c = Config()
1919
c.max_steps = int(4e5)
2020
c.nworkers = 12
2121
c.nsteps = 20
22-
c.set_env(lambda: PyBullet('Hopper'))
22+
c.set_env(lambda: PyBullet(envname))
2323
c.set_net_fn('actor-critic', net.actor_critic.fc_shared(policy=SeparateStdGaussianDist))
2424
c.set_parallel_env(pybullet_parallel())
2525
c.set_optimizer(kfac.default_sgd(eta_max=0.1))
@@ -34,4 +34,4 @@ def config() -> Config:
3434

3535

3636
if __name__ == '__main__':
37-
cli.run_cli(config(), AcktrAgent, script_path=os.path.realpath(__file__))
37+
cli.run_cli(config, AcktrAgent, script_path=os.path.realpath(__file__))

examples/aoc_atari.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from torch.optim import RMSprop
1010

1111

12-
def config() -> Config:
12+
def config(game: str = 'Breakout') -> Config:
1313
c = Config()
14-
c.set_env(lambda: Atari('Breakout', frame_stack=False))
14+
c.set_env(lambda: Atari(game, frame_stack=False))
1515
c.set_optimizer(
1616
lambda params: RMSprop(params, lr=7e-4, alpha=0.99, eps=1e-5)
1717
)
@@ -24,7 +24,7 @@ def config() -> Config:
2424
c.value_loss_weight = 1.0
2525
c.use_gae = False
2626
c.max_steps = int(2e7)
27-
c.eval_env = Atari('Breakout')
27+
c.eval_env = Atari(game)
2828
c.use_reward_monitor = True
2929
c.eval_deterministic = False
3030
c.episode_log_freq = 100
@@ -36,4 +36,4 @@ def config() -> Config:
3636

3737

3838
if __name__ == '__main__':
39-
cli.run_cli(config(), AocAgent, script_path=os.path.realpath(__file__))
39+
cli.run_cli(config, AocAgent, script_path=os.path.realpath(__file__))

examples/aoc_cart_pole.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@ def config() -> rainy.Config:
2323

2424

2525
if __name__ == '__main__':
26-
run_cli(config(), rainy.agents.AocAgent, script_path=os.path.realpath(__file__))
26+
run_cli(config, rainy.agents.AocAgent, script_path=os.path.realpath(__file__))

examples/ddqn_atari.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from torch.optim import RMSprop
88

99

10-
def config() -> Config:
10+
def config(game: str = 'Breakout') -> Config:
1111
c = Config()
12-
c.set_env(lambda: Atari('Breakout'))
12+
c.set_env(lambda: Atari(game))
1313
c.set_optimizer(
1414
lambda params: RMSprop(params, lr=0.00025, alpha=0.95, eps=0.01, centered=True)
1515
)
@@ -20,12 +20,12 @@ def config() -> Config:
2020
c.train_start = 50000
2121
c.sync_freq = 10000
2222
c.max_steps = int(2e7)
23-
c.eval_env = Atari('Breakout', episodic_life=False)
23+
c.eval_env = Atari(game, episodic_life=False)
2424
c.eval_freq = None
2525
c.seed = 1
2626
c.use_reward_monitor = True
2727
return c
2828

2929

3030
if __name__ == '__main__':
31-
cli.run_cli(config(), DoubleDqnAgent, script_path=os.path.realpath(__file__))
31+
cli.run_cli(config, DoubleDqnAgent, script_path=os.path.realpath(__file__))

examples/ddqn_cart_pole.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ def config() -> Config:
1111

1212

1313
if __name__ == '__main__':
14-
cli.run_cli(config(), DoubleDqnAgent, script_path=os.path.realpath(__file__))
14+
cli.run_cli(config, DoubleDqnAgent, script_path=os.path.realpath(__file__))

examples/dqn_atari.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from torch.optim import RMSprop
88

99

10-
def config() -> Config:
10+
def config(game: str = 'Breakout') -> Config:
1111
c = Config()
12-
c.set_env(lambda: Atari('Breakout'))
12+
c.set_env(lambda: Atari(game))
1313
c.set_optimizer(
1414
lambda params: RMSprop(params, lr=0.00025, alpha=0.95, eps=0.01, centered=True)
1515
)
@@ -20,12 +20,12 @@ def config() -> Config:
2020
c.train_start = 50000
2121
c.sync_freq = 10000
2222
c.max_steps = int(2e7)
23-
c.eval_env = Atari('Breakout')
23+
c.eval_env = Atari(game)
2424
c.eval_freq = None
2525
c.use_reward_monitor = True
2626
return c
2727

2828

2929
if __name__ == '__main__':
30-
cli.run_cli(config(), DqnAgent, script_path=os.path.realpath(__file__))
30+
cli.run_cli(config, DqnAgent, script_path=os.path.realpath(__file__))
3131

examples/dqn_cart_pole.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ def config() -> Config:
1111

1212

1313
if __name__ == '__main__':
14-
cli.run_cli(config(), DqnAgent, script_path=os.path.realpath(__file__))
14+
cli.run_cli(config, DqnAgent, script_path=os.path.realpath(__file__))

examples/ppo_atari.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from torch.optim import Adam
77

88

9-
def config() -> Config:
9+
def config(game: str = 'Breakout') -> Config:
1010
c = Config()
11-
c.set_env(lambda: Atari('Breakout', frame_stack=False))
11+
c.set_env(lambda: Atari(game, frame_stack=False))
1212
# c.set_net_fn('actor-critic', net.actor_critic.ac_conv(rnn=net.GruBlock))
1313
c.set_net_fn('actor-critic', net.actor_critic.ac_conv())
1414
c.set_parallel_env(atari_parallel())
@@ -27,12 +27,12 @@ def config() -> Config:
2727
c.use_reward_monitor = True
2828
c.lr_min = None # set 0.0 if you decrease ppo_clip
2929
# eval settings
30-
c.eval_env = Atari('Breakout')
30+
c.eval_env = Atari(game)
3131
c.episode_log_freq = 100
3232
c.eval_freq = None
3333
c.save_freq = None
3434
return c
3535

3636

3737
if __name__ == '__main__':
38-
cli.run_cli(config(), PpoAgent, script_path=os.path.realpath(__file__))
38+
cli.run_cli(config, PpoAgent, script_path=os.path.realpath(__file__))

examples/ppo_cart_pole.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ def config() -> rainy.Config:
2424

2525

2626
if __name__ == '__main__':
27-
run_cli(config(), rainy.agents.PpoAgent, script_path=os.path.realpath(__file__))
27+
run_cli(config, rainy.agents.PpoAgent, script_path=os.path.realpath(__file__))

examples/ppo_flicker_atari.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
import rainy.utils.cli as cli
66

77

8-
def config() -> rainy.Config:
9-
c = ppo_atari.config()
10-
c.set_env(lambda: Atari('Breakout', flicker_frame=True, frame_stack=False))
8+
def config(game: str = 'Breakout') -> rainy.Config:
9+
c = ppo_atari.config(game)
10+
c.set_env(lambda: Atari(game, flicker_frame=True, frame_stack=False))
1111
c.set_parallel_env(atari_parallel(frame_stack=False))
1212
c.set_net_fn('actor-critic', rainy.net.actor_critic.ac_conv(rnn=rainy.net.GruBlock))
13-
c.eval_env = Atari('Breakout', flicker_frame=True, frame_stack=True)
13+
c.eval_env = Atari(game, flicker_frame=True, frame_stack=True)
1414
return c
1515

1616

1717
if __name__ == '__main__':
18-
cli.run_cli(config(), rainy.agents.PpoAgent, script_path=realpath(__file__))
18+
cli.run_cli(config, rainy.agents.PpoAgent, script_path=realpath(__file__))

examples/ppo_halfcheetah.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from torch.optim import Adam
88

99

10-
def config() -> Config:
10+
def config(envname: str = 'HalfCheetah') -> Config:
1111
c = Config()
12-
c.set_env(lambda: PyBullet('HalfCheetah'))
12+
c.set_env(lambda: PyBullet(envname))
1313
c.set_net_fn('actor-critic', net.actor_critic.fc_shared(policy=SeparateStdGaussianDist))
1414
c.set_parallel_env(pybullet_parallel())
1515
c.set_optimizer(lambda params: Adam(params, lr=3.0e-4, eps=1.0e-4))
@@ -30,4 +30,4 @@ def config() -> Config:
3030

3131

3232
if __name__ == '__main__':
33-
cli.run_cli(config(), PpoAgent, script_path=os.path.realpath(__file__))
33+
cli.run_cli(config, PpoAgent, script_path=os.path.realpath(__file__))

examples/ppo_hopper.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
from torch.optim import Adam
88

99

10-
def config() -> Config:
10+
def config(envname: str = 'Hopper') -> Config:
1111
c = Config()
12-
c.set_env(lambda: PyBullet('Hopper'))
12+
c.set_env(lambda: PyBullet(envname))
1313
c.set_net_fn('actor-critic', net.actor_critic.fc_shared(policy=SeparateStdGaussianDist))
14-
c.set_parallel_env(pybullet_parallel(normalize_obs=True,normalize_reward=True))
14+
c.set_parallel_env(pybullet_parallel(normalize_obs=True, normalize_reward=True))
1515
c.set_optimizer(lambda params: Adam(params, lr=3.0e-4, eps=1.0e-4))
1616
c.max_steps = int(2e6)
1717
c.grad_clip = 0.5
@@ -30,4 +30,4 @@ def config() -> Config:
3030

3131

3232
if __name__ == '__main__':
33-
cli.run_cli(config(), PpoAgent, script_path=os.path.realpath(__file__))
33+
cli.run_cli(config, PpoAgent, script_path=os.path.realpath(__file__))

0 commit comments

Comments
 (0)