# MIT License
# Copyright (c) 2021 MARL @ SJTU
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# pragma: no cover
from argparse import ArgumentParser
import os
import time

from malib.runner import run
from malib.agent import IndependentAgent
from malib.scenarios.psro_scenario import PSROScenario
from malib.rl.dqn import DQNPolicy, DQNTrainer, DEFAULT_CONFIG
from malib.rollout.envs.open_spiel import env_desc_gen


if __name__ == "__main__":
    parser = ArgumentParser("Naive Policy Space Response Oracles.")
    parser.add_argument("--log_dir", default="./logs/", help="Log directory.")
    parser.add_argument(
        "--env_id", default="kuhn_poker", help="open_spiel environment id."
    )

    args = parser.parse_args()

    # Start from the default DQN training configuration and raise the
    # timestep budget for best-response training.
    trainer_config = DEFAULT_CONFIG["training_config"].copy()
    trainer_config["total_timesteps"] = int(1e6)

    training_config = {
        "type": IndependentAgent,
        "trainer_config": trainer_config,
        "custom_config": {},
    }
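    # Note: with IndependentAgent, each game agent runs its own independent
    # learner, which under PSRO serves as the best-response oracle (a reading
    # of this configuration, not a statement about malib internals).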
    rollout_config = {
        "fragment_length": 2000,  # timesteps collected per rollout thread
        "max_step": 200,
        "num_eval_episodes": 10,
        "num_threads": 2,
        "num_env_per_thread": 10,
        "num_eval_threads": 1,
        "use_subproc_env": False,
        "batch_mode": "time_step",
        "postprocessor_types": ["defaults"],
        # run evaluation every `eval_interval` rollout epochs.
        "eval_interval": 1,
        # one of three inference server modes: `local`, `pipe`, or `ray`.
        "inference_server": "ray",
    }
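    # With these settings, rollouts would use num_threads * num_env_per_thread
    # = 2 * 10 = 20 environment instances in parallel (an inference from the
    # config keys, not verified against the rollout worker implementation).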
    # Identity mapping: each agent keeps its own policy set (no sharing).
    agent_mapping_func = lambda agent: agent

    algorithms = {
        "default": (
            DQNPolicy,
            DQNTrainer,
            # model configuration, None for default
            {},
            {},
        )
    }
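    # The 4-tuple reads as (policy class, trainer class, model config,
    # custom config); the empty dicts presumably fall back to the defaults
    # shipped with malib.rl.dqn (an assumption about the expected layout).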
    env_description = env_desc_gen(env_id=args.env_id)
    runtime_logdir = os.path.join(args.log_dir, f"psro_{args.env_id}/{time.time()}")

    if not os.path.exists(runtime_logdir):
        os.makedirs(runtime_logdir)
    scenario = PSROScenario(
        name=f"psro_{args.env_id}",
        log_dir=runtime_logdir,
        algorithms=algorithms,
        env_description=env_description,
        training_config=training_config,
        rollout_config=rollout_config,
        # controls the outer PSRO loop: at most 50 iterations.
        global_stopping_conditions={"max_iteration": 50},
        agent_mapping_func=agent_mapping_func,
        # stopping conditions for training the best response in the inner loop.
        stopping_conditions={
            "training": {"max_iteration": int(1e4)},
            "rollout": {"max_iteration": 100},
        },
    )

    run(scenario)
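
# A minimal invocation sketch, assuming `malib`, `ray`, and `open_spiel` are
# installed and importable (the flags mirror the ArgumentParser above):
#
#   python run_psro.py --env_id kuhn_poker --log_dir ./logs/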