import os
import argparse
from copy import deepcopy
import numpy as np
from xuance.torch.agents import REGISTRY_Agents
from xuance.environment import make_envs
from xuance.torch.utils.operations import set_seed


class RunnerCompetition(object):
    def __init__(self, configs):
        self.configs = configs
        # set random seeds (configs is a list of per-group configs,
        # so the seed is taken from the first one)
        set_seed(configs[0].seed)

        # build the shared competitive environments and query the group info
        self.envs = make_envs(self.configs[0])
        self.envs.reset()
        self.group_info = self.envs.groups_infos
        self.groups = self.group_info['agent_groups']
        self.num_groups = self.group_info['num_groups']
        self.obs_space_groups = self.group_info['observation_space_groups']
        self.act_space_groups = self.group_info['action_space_groups']
        assert len(configs) == self.num_groups, "Number of configs (one per method) must equal the number of agent groups."
        # build one agent (method) per group, each seeing only its own
        # group's observation and action spaces
        self.agents = []
        for group in range(self.num_groups):
            _env_info = dict(num_agents=len(self.groups[group]),
                             num_envs=self.envs.num_envs,
                             agents=self.groups[group],
                             state_space=self.envs.state_space,
                             observation_space=self.obs_space_groups[group],
                             action_space=self.act_space_groups[group],
                             max_episode_steps=self.envs.max_episode_steps)
            _env = argparse.Namespace(**_env_info)
            self.agents.append(REGISTRY_Agents[self.configs[group].agent](self.configs[group], _env))

        self.observation_space = self.envs.observation_space
        self.n_envs = self.envs.num_envs
        self.rank = 0
        if self.agents[0].distributed_training:
            self.rank = int(os.environ['RANK'])
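
    # When distributed training is enabled, the RANK environment variable is
    # set by the PyTorch distributed launcher (e.g. torchrun); rprint below
    # restricts console output to the rank-0 process so that logs are not
    # duplicated across workers.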
    def rprint(self, info: str):
        if self.rank == 0:
            print(info)

    def run(self):
        if self.configs[0].test_mode:
            def env_fn():
                config_test = deepcopy(self.configs[0])
                config_test.parallels = 1
                config_test.render = True
                return make_envs(config_test)

            for agent in self.agents:
                agent.render = True
                agent.load_model(agent.model_dir_load)

            # evaluate the loaded models
            scores = self.test(env_fn, self.configs[0].test_episode)

            print(f"Mean Score: {np.mean(scores)}, Std: {np.std(scores)}")
            print("Finish testing.")
        else:
            n_train_steps = self.configs[0].running_steps // self.n_envs

            # train the agents of all groups
            self.train(n_train_steps)

            print("Finish training.")
            for agent in self.agents:
                agent.save_model("final_train_model.pth")

        for agent in self.agents:
            agent.finish()

    def benchmark(self):
        def env_fn():
            config_test = deepcopy(self.configs[0])
            config_test.parallels = 1  # config_test.test_episode
            return make_envs(config_test)

        train_steps = self.configs[0].running_steps // self.n_envs
        eval_interval = self.configs[0].eval_interval // self.n_envs
        test_episode = self.configs[0].test_episode
        num_epoch = int(train_steps / eval_interval)

        # initial evaluation (rank 0 only)
        test_scores = self.test(env_fn, test_episode) if self.rank == 0 else 0.0

        best_scores_info = {"mean": np.mean(test_scores),
                            "std": np.std(test_scores),
                            "step": self.agents[0].current_step}

        for i_epoch in range(num_epoch):
            print("Epoch: %d/%d:" % (i_epoch, num_epoch))

            # train for one evaluation interval
            self.train(eval_interval)

            if self.rank == 0:
                # re-evaluate and keep the best-performing models
                test_scores = self.test(env_fn, test_episode)

                if np.mean(test_scores) > best_scores_info["mean"]:
                    best_scores_info = {"mean": np.mean(test_scores),
                                        "std": np.std(test_scores),
                                        "step": self.agents[0].current_step}
                    # save best model
                    for agent in self.agents:
                        agent.save_model(model_name="best_model.pth")

        # end benchmarking
        print("Best Model Score: %.2f, std=%.2f" % (best_scores_info["mean"], best_scores_info["std"]))
        for agent in self.agents:
            agent.finish()
| 118 | + |
| 119 | + def train(self, eval_interval): |
| 120 | + return |
| 121 | + |
| 122 | + def test(self, env_fn, test_episode): |
| 123 | + scores = [None for handel in self.handles] |
| 124 | + return scores |
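
# Minimal usage sketch (illustrative only): the per-group configs are normally
# parsed from YAML files by xuance's helper utilities; the attribute values
# below are hypothetical placeholders, and whatever further fields the chosen
# agents require would have to be supplied as well.
#
#     from argparse import Namespace
#     cfg_red = Namespace(agent="IPPO", seed=1, test_mode=False,
#                         running_steps=100000, eval_interval=5000,
#                         test_episode=5, parallels=8)
#     cfg_blue = Namespace(agent="IQL", seed=1, test_mode=False,
#                          running_steps=100000, eval_interval=5000,
#                          test_episode=5, parallels=8)
#     runner = RunnerCompetition([cfg_red, cfg_blue])
#     runner.benchmark()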