-
Notifications
You must be signed in to change notification settings - Fork 118
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(rjy): add mountain_car env and its muzero visualization (#181)
* env(rjy): add mtcar env * feature(rjy): add vis for mz+mtcar * feature(rjy): add tsne for mtcar * feature(rjy): polish vis for mz+mtcar * polish(rjy): clear redundancy info * fix(rjy): fix typo * polish(rjy): compressed file size --------- Co-authored-by: nighood <[email protected]> Co-authored-by: 蒲源 <[email protected]>
- Loading branch information
1 parent
8acf6cf
commit c9fccf0
Showing
9 changed files
with
946 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
92 changes: 92 additions & 0 deletions
92
zoo/classic_control/mountain_car/config/mtcar_muzero_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from easydict import EasyDict | ||
|
||
# ============================================================== | ||
# begin of the most frequently changed config specified by the user | ||
# ============================================================== | ||
collector_env_num = 8 | ||
n_episode = 8 | ||
evaluator_env_num = 3 | ||
num_simulations = 25 | ||
update_per_collect = 100 | ||
batch_size = 256 | ||
max_env_step = int(1e6) | ||
reanalyze_ratio = 0 | ||
# ============================================================== | ||
# end of the most frequently changed config specified by the user | ||
# ============================================================== | ||
|
||
mountain_car_muzero_config = dict( | ||
exp_name=f'mountain_car_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', | ||
env=dict( | ||
env_name='MountainCar-v0', | ||
continuous=False, | ||
manually_discretization=False, | ||
collector_env_num=collector_env_num, | ||
evaluator_env_num=evaluator_env_num, | ||
n_evaluator_episode=evaluator_env_num, | ||
manager=dict(shared_memory=False, ), | ||
), | ||
policy=dict( | ||
eps=dict( | ||
# (bool) Whether to use eps greedy exploration in collecting data. | ||
eps_greedy_exploration_in_collect=True, | ||
# (str) The type of decaying epsilon. Options are 'linear', 'exp'. | ||
type='linear', | ||
# (float) The start value of eps. | ||
start=1., | ||
# (float) The end value of eps. | ||
end=0.05, | ||
# (int) The decay steps from start to end eps. | ||
decay=int(1e5), | ||
), | ||
model=dict( | ||
observation_shape=2, | ||
action_space_size=3, | ||
model_type='mlp', | ||
lstm_hidden_size=128, | ||
latent_state_dim=64, | ||
self_supervised_learning_loss=True, # NOTE: default is False. | ||
discrete_action_encoding_type='one_hot', | ||
norm_type='BN', | ||
), | ||
cuda=True, | ||
env_type='not_board_games', | ||
game_segment_length=50, | ||
update_per_collect=update_per_collect, | ||
batch_size=batch_size, | ||
optim_type='Adam', | ||
lr_piecewise_constant_decay=False, | ||
learning_rate=0.003, | ||
ssl_loss_weight=2, # NOTE: default is 0. | ||
num_simulations=num_simulations, | ||
reanalyze_ratio=reanalyze_ratio, | ||
n_episode=n_episode, | ||
eval_freq=int(2e2), | ||
replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. | ||
collector_env_num=collector_env_num, | ||
evaluator_env_num=evaluator_env_num, | ||
), | ||
) | ||
|
||
mountain_car_muzero_config = EasyDict(mountain_car_muzero_config) | ||
main_config = mountain_car_muzero_config | ||
|
||
mountain_car_muzero_create_config = dict( | ||
env=dict( | ||
type='mountain_car_lightzero', | ||
import_names=['zoo.classic_control.mountain_car.envs.mtcar_lightzero_env'], | ||
), | ||
env_manager=dict(type='subprocess'), | ||
policy=dict( | ||
type='muzero', | ||
import_names=['lzero.policy.muzero'], | ||
), | ||
) | ||
mountain_car_muzero_create_config = EasyDict(mountain_car_muzero_create_config) | ||
create_config = mountain_car_muzero_create_config | ||
|
||
if __name__ == "__main__": | ||
# Users can use different train entry by specifying the entry_type. | ||
from lzero.entry import train_muzero | ||
|
||
train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) |
Empty file.
56 changes: 56 additions & 0 deletions
56
zoo/classic_control/mountain_car/entry/mountain_car_eval.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from zoo.classic_control.mountain_car.config.mtcar_muzero_config import main_config, create_config | ||
from lzero.entry import eval_muzero | ||
import numpy as np | ||
|
||
if __name__ == "__main__": | ||
""" | ||
Entry point for the evaluation of the MuZero model on the CartPole environment. | ||
Variables: | ||
- model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the | ||
pretrained model. An absolute path is recommended. In LightZero, the path is usually something like | ||
``exp_name/ckpt/ckpt_best.pth.tar``. | ||
- returns_mean_seeds (:obj:`List[float]`): List to store the mean returns for each seed. | ||
- returns_seeds (:obj:`List[float]`): List to store the returns for each seed. | ||
- seeds (:obj:`List[int]`): List of seeds for the environment. | ||
- num_episodes_each_seed (:obj:`int`): Number of episodes to run for each seed. | ||
- total_test_episodes (:obj:`int`): Total number of test episodes, computed as the product of the number of | ||
seeds and the number of episodes per seed. | ||
""" | ||
# model_path = "./ckpt/ckpt_best.pth.tar" | ||
model_path = None | ||
returns_mean_seeds = [] | ||
returns_seeds = [] | ||
seeds = [0] | ||
num_episodes_each_seed = 2 | ||
total_test_episodes = num_episodes_each_seed * len(seeds) | ||
create_config.env_manager.type = 'base' # Visualization requires the 'type' to be set as base | ||
main_config.env.evaluator_env_num = 1 # Visualization requires the 'env_num' to be set as 1 | ||
main_config.env.n_evaluator_episode = total_test_episodes | ||
main_config.env.replay_path = './video' | ||
main_config.exp_name = f'lz_result/eval/muzero_eval_ls{main_config.policy.model.latent_state_dim}' | ||
|
||
for seed in seeds: | ||
""" | ||
- returns_mean (:obj:`float`): The mean return of the evaluation. | ||
- returns (:obj:`List[float]`): The returns of the evaluation. | ||
""" | ||
returns_mean, returns, trajectorys = eval_muzero( | ||
[main_config, create_config], | ||
seed=seed, | ||
num_episodes_each_seed=num_episodes_each_seed, | ||
print_seed_details=False, | ||
model_path=model_path | ||
) | ||
returns_mean_seeds.append(returns_mean) | ||
returns_seeds.append(returns) | ||
|
||
returns_mean_seeds = np.array(returns_mean_seeds) | ||
returns_seeds = np.array(returns_seeds) | ||
|
||
# Print evaluation results | ||
print("=" * 20) | ||
print(f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s).") | ||
print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.") | ||
print("Across all seeds, the mean reward is:", returns_mean_seeds.mean()) | ||
print("=" * 20) |
Oops, something went wrong.