Commit e6664fe: env with competition tasks (parent 79eac50)
8 files changed: +102, -27 lines

.gitignore (+1)

```diff
@@ -1,6 +1,7 @@
 .idea
 .vscode
 .DS_Store
+docs/.DS_Store
 **/.DS_Store
 logs/
 videos/
```

docs/.DS_Store (binary file, 0 bytes, not shown)

xuance/environment/utils/base.py (+12, -1)

```diff
@@ -118,14 +118,25 @@ def __init__(self, *args, **kwargs):
         self.num_agents: Optional[int] = None  # Number of all agents, e.g., 4.
         self.max_episode_steps: Optional[int] = None

-    def get_env_info(self):
+    def get_env_info(self) -> Dict[str, Any]:
         return {'state_space': self.state_space,
                 'observation_space': self.observation_space,
                 'action_space': self.action_space,
                 'agents': self.agents,
                 'num_agents': self.num_agents,
                 'max_episode_steps': self.max_episode_steps}

+    def get_groups_info(self) -> Dict[str, Any]:
+        agent_groups: List[AgentKeys] = []  # e.g., [['red_0', 'red_1'], ['blue_0', 'blue_1']]. Default is empty.
+        num_groups: int = 1  # The number of groups.
+        return {'num_groups': num_groups,
+                'agent_groups': agent_groups,
+                'observation_space_groups': [{k: self.observation_space[k] for i, k in enumerate(group)}
+                                             for group in agent_groups],
+                'action_space_groups': [{k: self.action_space[k] for i, k in enumerate(group)}
+                                        for group in agent_groups],
+                'num_agents_groups': [len(group) for group in agent_groups]}
+
     def agent_mask(self):
         """Returns boolean mask variables indicating which agents are currently alive."""
         return {agent: True for agent in self.agents}
```
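
The base class returns an empty grouping by default, so competitive tasks are expected to override this hook. Below is a minimal sketch of such an override for a hypothetical 2-vs-2 red/blue environment; `RedBlueEnv` and its spaces are invented for illustration, and the import path follows XuanCe's documented custom-environment API:

```python
# Hypothetical example: a 2-vs-2 competitive env exposing two agent groups.
from gymnasium.spaces import Box, Discrete
from xuance.environment import RawMultiAgentEnv  # base class edited in this commit

class RedBlueEnv(RawMultiAgentEnv):
    def __init__(self):
        super().__init__()
        self.agents = ['red_0', 'red_1', 'blue_0', 'blue_1']
        self.num_agents = 4
        self.observation_space = {k: Box(-1.0, 1.0, shape=(8,)) for k in self.agents}
        self.action_space = {k: Discrete(5) for k in self.agents}

    def get_groups_info(self):
        # Two groups, one per team, in the same dict layout the base class returns.
        agent_groups = [['red_0', 'red_1'], ['blue_0', 'blue_1']]
        return {'num_groups': len(agent_groups),
                'agent_groups': agent_groups,
                'observation_space_groups': [{k: self.observation_space[k] for k in group}
                                             for group in agent_groups],
                'action_space_groups': [{k: self.action_space[k] for k in group}
                                        for group in agent_groups],
                'num_agents_groups': [len(group) for group in agent_groups]}
```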

xuance/environment/utils/wrapper.py (+1)

```diff
@@ -131,6 +131,7 @@ def __init__(self, env, **kwargs):
         self.num_agents = self.env.num_agents  # Number of all agents, e.g., 4.
         self._episode_score = {agent: 0.0 for agent in self.agents}
         self.env_info = self.env.get_env_info()
+        self.groups_info = self.env.get_groups_info()

     def reset(self, **kwargs) -> Tuple[dict, dict]:
         """Resets the environment with kwargs."""
```

xuance/environment/vector_envs/dummy/dummy_vec_maenv.py (+1)

```diff
@@ -21,6 +21,7 @@ def __init__(self, env_fns, env_seed):
         VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)

         self.env_info = env.env_info
+        self.groups_info = env.groups_info
         self.agents = env.agents
         self.num_agents = env.num_agents
         self.state_space = env.state_space  # Type: Box
```

xuance/environment/vector_envs/subprocess/subproc_vec_maenv.py (+5)

```diff
@@ -37,6 +37,9 @@ def step_env(env, action):
        elif cmd == 'get_env_info':
            env_info = envs[0].env_info
            remote.send(CloudpickleWrapper(env_info))
+        elif cmd == 'get_groups_info':
+            groups_info = envs[0].groups_info
+            remote.send(CloudpickleWrapper(groups_info))
        else:
            raise NotImplementedError
    except KeyboardInterrupt:
@@ -99,6 +102,8 @@ def __init__(self, env_fns, env_seed, context='spawn', in_series=1):

         self.actions = None
         self.max_episode_steps = self.env_info['max_episode_steps']
+        self.remotes[0].send(('get_groups_info', None))
+        self.groups_info = self.remotes[0].recv().x

     def reset(self):
         self._assert_not_closed()
```
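
The two hunks form the same request/reply handshake the other worker commands use: the parent sends a `(cmd, data)` tuple down the pipe, the worker answers with the payload wrapped in `CloudpickleWrapper`, and the parent unwraps it via `.x`. A self-contained sketch of that pattern, with a stand-in `CloudpickleWrapper` (the real one lives in xuance and also carries pickling logic):

```python
# Minimal sketch of the pipe-based command protocol; names mirror the diff.
import multiprocessing as mp

class CloudpickleWrapper:
    """Stand-in for xuance's CloudpickleWrapper: carries a payload as `.x`."""
    def __init__(self, x):
        self.x = x

def worker(remote, groups_info):
    while True:
        cmd, _ = remote.recv()
        if cmd == 'get_groups_info':
            remote.send(CloudpickleWrapper(groups_info))
        elif cmd == 'close':
            remote.close()
            break

if __name__ == '__main__':
    parent, child = mp.Pipe()
    info = {'num_groups': 2, 'agent_groups': [['red_0'], ['blue_0']]}
    proc = mp.Process(target=worker, args=(child, info))
    proc.start()
    parent.send(('get_groups_info', None))
    print(parent.recv().x['num_groups'])  # -> 2
    parent.send(('close', None))
    proc.join()
```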

xuance/torch/agents/base/agents_marl.py (+4, -1)

```diff
@@ -52,7 +52,10 @@ def __init__(self,

         # Environment attributes.
         self.envs = envs
-        self.envs.reset()
+        try:  # in competition mode, `envs` is a Namespace of env info with no reset().
+            self.envs.reset()
+        except AttributeError:
+            pass
         self.n_agents = self.config.n_agents = envs.num_agents
         self.render = config.render
         self.fps = config.fps
```
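
Context for the new guard: `RunnerCompetition` (the last file in this commit) constructs each agent with an `argparse.Namespace` of per-group env info instead of a live vectorized environment, and a `Namespace` has no `reset()`. A two-line illustration of the failure mode the `try/except` absorbs:

```python
import argparse

env_stub = argparse.Namespace(num_agents=2, agents=['red_0', 'blue_0'])
try:
    env_stub.reset()  # Namespace has no reset(); raises AttributeError
except AttributeError:
    pass  # same outcome as the guard added above
```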
(file name not shown) (+78, -25)

```diff
@@ -1,71 +1,124 @@
 import os
-import copy
+import argparse
+from copy import deepcopy
 import numpy as np
-from xuance.torch.runners import RunnerBase
 from xuance.torch.agents import REGISTRY_Agents
 from xuance.environment import make_envs
+from xuance.torch.utils.operations import set_seed


-class RunnerMARL(object):
-    def __init__(self, config):
-        super().__init__()
-
-        self.agents = REGISTRY_Agents[config.agent](config, self.envs)
-        self.config = config
+class RunnerCompetition(object):
+    def __init__(self, configs):
+        self.configs = configs
+        # set random seeds
+        set_seed(configs[0].seed)

-        if self.agents.distributed_training:
+        # build environments
+        self.envs = make_envs(self.configs[0])
+        self.envs.reset()
+        self.group_info = self.envs.groups_info
+        self.groups = self.group_info['agent_groups']
+        self.num_groups = self.group_info['num_groups']
+        self.obs_space_groups = self.group_info['observation_space_groups']
+        self.act_space_groups = self.group_info['action_space_groups']
+        assert len(configs) == self.num_groups, "The number of configs (methods) must equal the number of agent groups."
+        self.agents = []
+        for group in range(self.num_groups):
+            _env_info = dict(num_agents=len(self.groups[group]),
+                             num_envs=self.envs.num_envs,
+                             agents=self.groups[group],
+                             state_space=self.envs.state_space,
+                             observation_space=self.obs_space_groups[group],
+                             action_space=self.act_space_groups[group],
+                             max_episode_steps=self.envs.max_episode_steps)
+            _env = argparse.Namespace(**_env_info)
+            self.agents.append(REGISTRY_Agents[self.configs[group].agent](self.configs[group], _env))
+
+        self.observation_space = self.envs.observation_space
+        self.n_envs = self.envs.num_envs
+        self.rank = 0
+        if self.agents[0].distributed_training:
             self.rank = int(os.environ['RANK'])

+    def rprint(self, info: str):
+        if self.rank == 0:
+            print(info)
+
     def run(self):
-        if self.config.test_mode:
+        if self.configs[0].test_mode:
             def env_fn():
-                config_test = copy.deepcopy(self.config)
+                config_test = deepcopy(self.configs[0])
                 config_test.parallels = 1
                 config_test.render = True
                 return make_envs(config_test)
-            self.agents.render = True
-            self.agents.load_model(self.agents.model_dir_load)
+
+            for agent in self.agents:
+                agent.render = True
+                agent.load_model(agent.model_dir_load)
+
+            # ... Here is test ...
             scores = self.test(env_fn, self.configs[0].test_episode)
+
             print(f"Mean Score: {np.mean(scores)}, Std: {np.std(scores)}")
             print("Finish testing.")
         else:
-            n_train_steps = self.config.running_steps // self.n_envs
+            n_train_steps = self.configs[0].running_steps // self.n_envs
+
+            # ... Here is train ...
             self.train(n_train_steps)
+
             print("Finish training.")
-            self.agents.save_model("final_train_model.pth")
+            for agent in self.agents:
+                agent.save_model("final_train_model.pth")

-        self.envs.close()
-        self.agents.finish()
+        for agent in self.agents:
+            agent.finish()

     def benchmark(self):
         def env_fn():
-            config_test = copy.deepcopy(self.config)
+            config_test = deepcopy(self.configs[0])
             config_test.parallels = 1  # config_test.test_episode
             return make_envs(config_test)

-        train_steps = self.config.running_steps // self.n_envs
-        eval_interval = self.config.eval_interval // self.n_envs
-        test_episode = self.config.test_episode
+        train_steps = self.configs[0].running_steps // self.n_envs
+        eval_interval = self.configs[0].eval_interval // self.n_envs
+        test_episode = self.configs[0].test_episode
         num_epoch = int(train_steps / eval_interval)

+        # ... Here is test ...
         test_scores = self.test(env_fn, test_episode) if self.rank == 0 else 0.0
+
         best_scores_info = {"mean": np.mean(test_scores),
                             "std": np.std(test_scores),
-                            "step": self.agents.current_step}
+                            "step": self.agents[0].current_step}
+
         for i_epoch in range(num_epoch):
             print("Epoch: %d/%d:" % (i_epoch, num_epoch))
+
+            # ... Here is train ...
             self.train(eval_interval)
+
             if self.rank == 0:
+
+                # ... Here is test ...
                 test_scores = self.test(env_fn, test_episode)

                 if np.mean(test_scores) > best_scores_info["mean"]:
                     best_scores_info = {"mean": np.mean(test_scores),
                                         "std": np.std(test_scores),
                                         "step": self.agents[0].current_step}
                     # save best model
-                    self.agents.save_model(model_name="best_model.pth")
+                    for agent in self.agents:
+                        agent.save_model(model_name="best_model.pth")

         # end benchmarking
         print("Best Model Score: %.2f, std=%.2f" % (best_scores_info["mean"], best_scores_info["std"]))
-        self.envs.close()
-        self.agents.finish()
+        for agent in self.agents:
+            agent.finish()
+
+    def train(self, eval_interval):
+        # Placeholder in this commit: the competition training loop is not implemented yet.
+        return
+
+    def test(self, env_fn, test_episode):
+        # Placeholder in this commit: one score slot per group, filled in later.
+        scores = [None for _ in self.agents]
+        return scores
```
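
Taken together, launching the new runner might look like the hedged sketch below. The env fields, agent names, and the import location of `RunnerCompetition` are assumptions (the runner file's path is not shown in this capture); the one-config-per-agent-group contract is the only part guaranteed by the code above.

```python
# Hypothetical usage sketch; field names, env ids, and the import location are assumptions.
import argparse
from xuance.torch.runners import RunnerCompetition  # assumed export location

common = dict(env_name="mpe", env_id="simple_push_v3", seed=1, parallels=8,
              running_steps=100_000, eval_interval=5_000, test_episode=5,
              test_mode=False, render=False, fps=15)

cfg_red = argparse.Namespace(agent="IQL", **common)    # method for group 0 (e.g., the red team)
cfg_blue = argparse.Namespace(agent="IPPO", **common)  # method for group 1 (e.g., the blue team)

# One config per agent group; the runner asserts len(configs) == num_groups.
runner = RunnerCompetition(configs=[cfg_red, cfg_blue])
runner.run()
```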
