# cartpole_c51_deploy.py
import gym
import torch
from easydict import EasyDict
from ding.config import compile_config
from ding.envs import DingEnvWrapper
from ding.policy import C51Policy, single_env_forward_wrapper
from ding.model import C51DQN
from dizoo.classic_control.cartpole.config.cartpole_c51_config import cartpole_c51_config, cartpole_c51_create_config
def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str) -> None:
    """Run a single CartPole-v0 episode with a C51 policy restored from a checkpoint.

    The function compiles the config, rebuilds the ``C51DQN`` model, loads the
    saved weights on CPU, rolls the policy out until the environment reports
    ``done``, and prints the accumulated reward.

    Args:
        main_config: Main experiment config. Its ``exp_name`` field is
            overwritten in place with ``'cartpole_c51_deploy'``.
        create_config: Creation config forwarded to ``compile_config`` as
            ``create_cfg``.
        ckpt_path: Path to a checkpoint readable by ``torch.load`` whose
            ``'model'`` entry is the state dict for ``C51DQN``.
    """
    main_config.exp_name = 'cartpole_c51_deploy'
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)

    # Rebuild the network and restore the trained weights on CPU.
    model = C51DQN(**cfg.policy.model)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict['model'])

    # Use the policy's eval-mode interface; the wrapper adapts its forward
    # function for a single (non-batched) environment -- presumably handling
    # the batching/unbatching of observations and actions.
    policy = C51Policy(cfg.policy, model=model).eval_mode
    forward_fn = single_env_forward_wrapper(policy.forward)

    # Create the environment last so a failure while loading the checkpoint
    # does not leave an open environment behind.
    env = DingEnvWrapper(gym.make('CartPole-v0'), EasyDict(env_wrapper='default'))
    try:
        obs = env.reset()
        returns = 0.
        while True:
            action = forward_fn(obs)
            obs, rew, done, info = env.step(action)
            returns += rew
            if done:
                break
    finally:
        # Always release the environment, even if the rollout raises.
        env.close()
    print(f'Deploy is finished, final episode return is: {returns}')
if __name__ == "__main__":
    # Script entry point: deploy the best checkpoint saved by the seed-0
    # CartPole C51 training run.
    main(
        main_config=cartpole_c51_config,
        create_config=cartpole_c51_create_config,
        ckpt_path='cartpole_c51_seed0/ckpt/ckpt_best.pth.tar',
    )