diff --git a/rl_games/common/env_configurations.py b/rl_games/common/env_configurations.py
index 01102be2..ea1f7112 100644
--- a/rl_games/common/env_configurations.py
+++ b/rl_games/common/env_configurations.py
@@ -86,16 +86,6 @@ def create_slime_gym_env(**kwargs):
     env = gym.make(name, **kwargs)
     return env
 
-def create_connect_four_env(**kwargs):
-    from rl_games.envs.connect4_selfplay import ConnectFourSelfPlay
-    name = kwargs.pop('name')
-    limit_steps = kwargs.pop('limit_steps', False)
-    self_play = kwargs.pop('self_play', False)
-    if self_play:
-        env = ConnectFourSelfPlay(name, **kwargs)
-    else:
-        env = gym.make(name, **kwargs)
-    return env
 
 def create_atari_gym_env(**kwargs):
     #frames = kwargs.pop('frames', 1)
@@ -391,10 +381,6 @@ def create_env(name, **kwargs):
         'env_creator' : lambda **kwargs : create_minigrid_env(kwargs.pop('name'), **kwargs),
         'vecenv_type' : 'RAY'
     },
-    'connect4_env' : {
-        'env_creator' : lambda **kwargs : create_connect_four_env(**kwargs),
-        'vecenv_type' : 'RAY'
-    },
     'multiwalker_env' : {
         'env_creator' : lambda **kwargs : create_multiwalker_env(**kwargs),
         'vecenv_type' : 'RAY'
diff --git a/rl_games/common/experience.py b/rl_games/common/experience.py
index 7b895eef..9cc880a6 100644
--- a/rl_games/common/experience.py
+++ b/rl_games/common/experience.py
@@ -20,7 +20,7 @@ def __init__(self, size, ob_space):
         self._next_obses = np.zeros((size,) + ob_space.shape, dtype=ob_space.dtype)
         self._rewards = np.zeros(size)
         self._actions = np.zeros(size, dtype=np.int32)
-        self._dones = np.zeros(size, dtype=np.bool)
+        self._dones = np.zeros(size, dtype=bool)
         self._maxsize = size
         self._next_idx = 0
 
@@ -341,7 +341,7 @@ def _init_from_env_info(self, env_info):
         if self.is_discrete or self.is_multi_discrete:
             self.tensor_dict['actions'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=int), obs_base_shape)
             if self.use_action_masks:
-                self.tensor_dict['action_masks'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape + (np.sum(self.actions_num),), dtype=np.bool), obs_base_shape)
+                self.tensor_dict['action_masks'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape + (np.sum(self.actions_num),), dtype=bool), obs_base_shape)
         if self.is_continuous:
             self.tensor_dict['actions'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=np.float32), obs_base_shape)
             self.tensor_dict['mus'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=np.float32), obs_base_shape)
diff --git a/rl_games/common/vecenv.py b/rl_games/common/vecenv.py
index 646da555..01016723 100644
--- a/rl_games/common/vecenv.py
+++ b/rl_games/common/vecenv.py
@@ -7,7 +7,6 @@
 from time import sleep
 import torch
 
-
 class RayWorker:
     def __init__(self, config_name, config):
         self.env = configurations[config_name]['env_creator'](**config)
@@ -96,14 +95,16 @@ def get_env_info(self):
 
 
 class RayVecEnv(IVecEnv):
+    import ray
+
     def __init__(self, config_name, num_actors, **kwargs):
         self.config_name = config_name
         self.num_actors = num_actors
         self.use_torch = False
         self.seed = kwargs.pop('seed', None)
-        import ray
-        self.remote_worker = ray.remote(RayWorker)
+
+        self.remote_worker = self.ray.remote(RayWorker)
         self.workers = [self.remote_worker.remote(self.config_name, kwargs) for i in range(self.num_actors)]
 
         if self.seed is not None:
@@ -111,15 +112,15 @@ def __init__(self, config_name, num_actors, **kwargs):
             seed_set = []
             for (seed, worker) in zip(seeds, self.workers):
                seed_set.append(worker.seed.remote(seed))
-            ray.get(seed_set)
+            self.ray.get(seed_set)
 
         res = self.workers[0].get_number_of_agents.remote()
-        self.num_agents = ray.get(res)
+        self.num_agents = self.ray.get(res)
 
         res = self.workers[0].get_env_info.remote()
-        env_info = ray.get(res)
+        env_info = self.ray.get(res)
         res = self.workers[0].can_concat_infos.remote()
-        can_concat_infos = ray.get(res)
+        can_concat_infos = self.ray.get(res)
         self.use_global_obs = env_info['use_global_observations']
         self.concat_infos = can_concat_infos
         self.obs_type_dict = type(env_info.get('observation_space')) is gym.spaces.Dict
@@ -139,7 +140,7 @@ def step(self, actions):
         for num, worker in enumerate(self.workers):
             res_obs.append(worker.step.remote(actions[self.num_agents * num: self.num_agents * num + self.num_agents]))
 
-        all_res = ray.get(res_obs)
+        all_res = self.ray.get(res_obs)
         for res in all_res:
             cobs, crewards, cdones, cinfos = res
             if self.use_global_obs:
@@ -171,27 +172,27 @@ def step(self, actions):
 
     def get_env_info(self):
         res = self.workers[0].get_env_info.remote()
-        return ray.get(res)
+        return self.ray.get(res)
 
     def set_weights(self, indices, weights):
         res = []
         for ind in indices:
             res.append(self.workers[ind].set_weights.remote(weights))
-        ray.get(res)
+        self.ray.get(res)
 
     def has_action_masks(self):
         return True
 
     def get_action_masks(self):
         mask = [worker.get_action_mask.remote() for worker in self.workers]
-        masks = ray.get(mask)
+        masks = self.ray.get(mask)
         return np.concatenate(masks, axis=0)
 
     def reset(self):
         res_obs = [worker.reset.remote() for worker in self.workers]
         newobs, newstates = [],[]
         for res in res_obs:
-            cobs = ray.get(res)
+            cobs = self.ray.get(res)
             if self.use_global_obs:
                 newobs.append(cobs["obs"])
                 newstates.append(cobs["state"])
diff --git a/rl_games/configs/smac/10m_vs_11m_torch.yaml b/rl_games/configs/smac/v1/10m_vs_11m_torch.yaml
similarity index 100%
rename from rl_games/configs/smac/10m_vs_11m_torch.yaml
rename to rl_games/configs/smac/v1/10m_vs_11m_torch.yaml
diff --git a/rl_games/configs/smac/27m_vs_30m_cv.yaml b/rl_games/configs/smac/v1/27m_vs_30m_cv.yaml
similarity index 100%
rename from rl_games/configs/smac/27m_vs_30m_cv.yaml
rename to rl_games/configs/smac/v1/27m_vs_30m_cv.yaml
diff --git a/rl_games/configs/smac/27m_vs_30m_torch.yaml b/rl_games/configs/smac/v1/27m_vs_30m_torch.yaml
similarity index 100%
rename from rl_games/configs/smac/27m_vs_30m_torch.yaml
rename to rl_games/configs/smac/v1/27m_vs_30m_torch.yaml
diff --git a/rl_games/configs/smac/2m_vs_1z.yaml b/rl_games/configs/smac/v1/2m_vs_1z.yaml
similarity index 100%
rename from rl_games/configs/smac/2m_vs_1z.yaml
rename to rl_games/configs/smac/v1/2m_vs_1z.yaml
diff --git a/rl_games/configs/smac/2m_vs_1z_torch.yaml b/rl_games/configs/smac/v1/2m_vs_1z_torch.yaml
similarity index 100%
rename from rl_games/configs/smac/2m_vs_1z_torch.yaml
rename to rl_games/configs/smac/v1/2m_vs_1z_torch.yaml
diff --git a/rl_games/configs/smac/2s_vs_1c.yaml b/rl_games/configs/smac/v1/2s_vs_1c.yaml
similarity index 100%
rename from rl_games/configs/smac/2s_vs_1c.yaml
rename to rl_games/configs/smac/v1/2s_vs_1c.yaml
diff --git a/rl_games/configs/smac/3m_cnn_torch.yaml b/rl_games/configs/smac/v1/3m_cnn_torch.yaml
similarity index 100%
rename from rl_games/configs/smac/3m_cnn_torch.yaml
rename to rl_games/configs/smac/v1/3m_cnn_torch.yaml
diff --git a/rl_games/configs/smac/3m_torch.yaml b/rl_games/configs/smac/v1/3m_torch.yaml
similarity index 100%
rename from rl_games/configs/smac/3m_torch.yaml
rename to rl_games/configs/smac/v1/3m_torch.yaml
diff --git a/rl_games/configs/smac/3m_torch_cv.yaml b/rl_games/configs/smac/v1/3m_torch_cv.yaml
similarity index 100%
rename from rl_games/configs/smac/3m_torch_cv.yaml
rename to rl_games/configs/smac/v1/3m_torch_cv.yaml
diff --git a/rl_games/configs/smac/3m_torch_cv_joint.yaml b/rl_games/configs/smac/v1/3m_torch_cv_joint.yaml
similarity index 100%
rename from rl_games/configs/smac/3m_torch_cv_joint.yaml
rename to rl_games/configs/smac/v1/3m_torch_cv_joint.yaml
diff --git a/rl_games/configs/smac/3m_torch_cv_rnn.yaml b/rl_games/configs/smac/v1/3m_torch_cv_rnn.yaml
similarity index 100%
rename from rl_games/configs/smac/3m_torch_cv_rnn.yaml
rename to rl_games/configs/smac/v1/3m_torch_cv_rnn.yaml
diff --git a/rl_games/configs/smac/3m_torch_rnn.yaml b/rl_games/configs/smac/v1/3m_torch_rnn.yaml
similarity index 100%
rename from rl_games/configs/smac/3m_torch_rnn.yaml
rename to rl_games/configs/smac/v1/3m_torch_rnn.yaml
diff --git a/rl_games/configs/smac/3m_torch_sparse.yaml b/rl_games/configs/smac/v1/3m_torch_sparse.yaml
similarity index 100%
rename from rl_games/configs/smac/3m_torch_sparse.yaml
rename to rl_games/configs/smac/v1/3m_torch_sparse.yaml
diff --git a/rl_games/configs/smac/3s5z_vs_3s6z_torch.yaml b/rl_games/configs/smac/v1/3s5z_vs_3s6z_torch.yaml
similarity index 100%
rename from rl_games/configs/smac/3s5z_vs_3s6z_torch.yaml
rename to rl_games/configs/smac/v1/3s5z_vs_3s6z_torch.yaml
diff --git a/rl_games/configs/smac/3s5z_vs_3s6z_torch_cv.yaml b/rl_games/configs/smac/v1/3s5z_vs_3s6z_torch_cv.yaml
similarity index 100%
rename from rl_games/configs/smac/3s5z_vs_3s6z_torch_cv.yaml
rename to rl_games/configs/smac/v1/3s5z_vs_3s6z_torch_cv.yaml
diff --git a/rl_games/configs/smac/3s_vs_4z.yaml b/rl_games/configs/smac/v1/3s_vs_4z.yaml
similarity index 100%
rename from rl_games/configs/smac/3s_vs_4z.yaml
rename to rl_games/configs/smac/v1/3s_vs_4z.yaml
diff --git a/rl_games/configs/smac/3s_vs_5z.yaml b/rl_games/configs/smac/v1/3s_vs_5z.yaml
similarity index 100%
rename from rl_games/configs/smac/3s_vs_5z.yaml
rename to rl_games/configs/smac/v1/3s_vs_5z.yaml
diff --git a/rl_games/configs/smac/3s_vs_5z_cv.yaml b/rl_games/configs/smac/v1/3s_vs_5z_cv.yaml
similarity index 100%
rename from rl_games/configs/smac/3s_vs_5z_cv.yaml
rename to rl_games/configs/smac/v1/3s_vs_5z_cv.yaml
diff --git a/rl_games/configs/smac/3s_vs_5z_cv_rnn.yaml b/rl_games/configs/smac/v1/3s_vs_5z_cv_rnn.yaml
similarity index 100%
rename from rl_games/configs/smac/3s_vs_5z_cv_rnn.yaml
rename to rl_games/configs/smac/v1/3s_vs_5z_cv_rnn.yaml
diff --git a/rl_games/configs/smac/3s_vs_5z_torch_lstm.yaml b/rl_games/configs/smac/v1/3s_vs_5z_torch_lstm.yaml
similarity index 100%
rename from rl_games/configs/smac/3s_vs_5z_torch_lstm.yaml
rename to rl_games/configs/smac/v1/3s_vs_5z_torch_lstm.yaml
diff --git a/rl_games/configs/smac/3s_vs_5z_torch_lstm2.yaml b/rl_games/configs/smac/v1/3s_vs_5z_torch_lstm2.yaml
similarity index 100%
rename from rl_games/configs/smac/3s_vs_5z_torch_lstm2.yaml
rename to rl_games/configs/smac/v1/3s_vs_5z_torch_lstm2.yaml
diff --git a/rl_games/configs/smac/5m_vs_6m_rnn.yaml b/rl_games/configs/smac/v1/5m_vs_6m_rnn.yaml
similarity index 100%
rename from rl_games/configs/smac/5m_vs_6m_rnn.yaml
rename to rl_games/configs/smac/v1/5m_vs_6m_rnn.yaml
diff --git a/rl_games/configs/smac/5m_vs_6m_rnn_cv.yaml b/rl_games/configs/smac/v1/5m_vs_6m_rnn_cv.yaml
similarity index 100%
rename from rl_games/configs/smac/5m_vs_6m_rnn_cv.yaml
rename to rl_games/configs/smac/v1/5m_vs_6m_rnn_cv.yaml
diff --git a/rl_games/configs/smac/5m_vs_6m_torch.yaml b/rl_games/configs/smac/v1/5m_vs_6m_torch.yaml
similarity index 100%
rename from rl_games/configs/smac/5m_vs_6m_torch.yaml
rename to rl_games/configs/smac/v1/5m_vs_6m_torch.yaml
diff --git a/rl_games/configs/smac/6h_vs_8z_torch.yaml b/rl_games/configs/smac/v1/6h_vs_8z_torch.yaml
similarity index 100%
rename from rl_games/configs/smac/6h_vs_8z_torch.yaml
rename to rl_games/configs/smac/v1/6h_vs_8z_torch.yaml
diff --git a/rl_games/configs/smac/6h_vs_8z_torch_cv.yaml b/rl_games/configs/smac/v1/6h_vs_8z_torch_cv.yaml
similarity index 100%
rename from rl_games/configs/smac/6h_vs_8z_torch_cv.yaml
rename to rl_games/configs/smac/v1/6h_vs_8z_torch_cv.yaml
diff --git a/rl_games/configs/smac/8m_torch.yaml b/rl_games/configs/smac/v1/8m_torch.yaml
similarity index 100%
rename from rl_games/configs/smac/8m_torch.yaml
rename to rl_games/configs/smac/v1/8m_torch.yaml
diff --git a/rl_games/configs/smac/8m_torch_cv.yaml b/rl_games/configs/smac/v1/8m_torch_cv.yaml
similarity index 100%
rename from rl_games/configs/smac/8m_torch_cv.yaml
rename to rl_games/configs/smac/v1/8m_torch_cv.yaml
diff --git a/rl_games/configs/smac/MMM2_torch.yaml b/rl_games/configs/smac/v1/MMM2_torch.yaml
similarity index 100%
rename from rl_games/configs/smac/MMM2_torch.yaml
rename to rl_games/configs/smac/v1/MMM2_torch.yaml
diff --git a/rl_games/configs/smac/corridor_torch.yaml b/rl_games/configs/smac/v1/corridor_torch.yaml
similarity index 100%
rename from rl_games/configs/smac/corridor_torch.yaml
rename to rl_games/configs/smac/v1/corridor_torch.yaml
diff --git a/rl_games/configs/smac/corridor_torch_cv.yaml b/rl_games/configs/smac/v1/corridor_torch_cv.yaml
similarity index 100%
rename from rl_games/configs/smac/corridor_torch_cv.yaml
rename to rl_games/configs/smac/v1/corridor_torch_cv.yaml
diff --git a/rl_games/configs/smac/runs/2c_vs_64zg.yaml b/rl_games/configs/smac/v1/runs/2c_vs_64zg.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/2c_vs_64zg.yaml
rename to rl_games/configs/smac/v1/runs/2c_vs_64zg.yaml
diff --git a/rl_games/configs/smac/runs/2c_vs_64zg_neg.yaml b/rl_games/configs/smac/v1/runs/2c_vs_64zg_neg.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/2c_vs_64zg_neg.yaml
rename to rl_games/configs/smac/v1/runs/2c_vs_64zg_neg.yaml
diff --git a/rl_games/configs/smac/runs/2s3z.yaml b/rl_games/configs/smac/v1/runs/2s3z.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/2s3z.yaml
rename to rl_games/configs/smac/v1/runs/2s3z.yaml
diff --git a/rl_games/configs/smac/runs/2s3z_neg.yaml b/rl_games/configs/smac/v1/runs/2s3z_neg.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/2s3z_neg.yaml
rename to rl_games/configs/smac/v1/runs/2s3z_neg.yaml
diff --git a/rl_games/configs/smac/runs/2s_vs_1c.yaml b/rl_games/configs/smac/v1/runs/2s_vs_1c.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/2s_vs_1c.yaml
rename to rl_games/configs/smac/v1/runs/2s_vs_1c.yaml
diff --git a/rl_games/configs/smac/runs/2s_vs_1c_neg.yaml b/rl_games/configs/smac/v1/runs/2s_vs_1c_neg.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/2s_vs_1c_neg.yaml
rename to rl_games/configs/smac/v1/runs/2s_vs_1c_neg.yaml
diff --git a/rl_games/configs/smac/runs/3s5z.yaml b/rl_games/configs/smac/v1/runs/3s5z.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/3s5z.yaml
rename to rl_games/configs/smac/v1/runs/3s5z.yaml
diff --git a/rl_games/configs/smac/runs/3s5z_neg.yaml b/rl_games/configs/smac/v1/runs/3s5z_neg.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/3s5z_neg.yaml
rename to rl_games/configs/smac/v1/runs/3s5z_neg.yaml
diff --git a/rl_games/configs/smac/runs/3s_vs_5z.yaml b/rl_games/configs/smac/v1/runs/3s_vs_5z.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/3s_vs_5z.yaml
rename to rl_games/configs/smac/v1/runs/3s_vs_5z.yaml
diff --git a/rl_games/configs/smac/runs/3s_vs_5z_neg.yaml b/rl_games/configs/smac/v1/runs/3s_vs_5z_neg.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/3s_vs_5z_neg.yaml
rename to rl_games/configs/smac/v1/runs/3s_vs_5z_neg.yaml
diff --git a/rl_games/configs/smac/runs/3s_vs_5z_neg_joint.yaml b/rl_games/configs/smac/v1/runs/3s_vs_5z_neg_joint.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/3s_vs_5z_neg_joint.yaml
rename to rl_games/configs/smac/v1/runs/3s_vs_5z_neg_joint.yaml
diff --git a/rl_games/configs/smac/runs/6h_vs_8z.yaml b/rl_games/configs/smac/v1/runs/6h_vs_8z.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/6h_vs_8z.yaml
rename to rl_games/configs/smac/v1/runs/6h_vs_8z.yaml
diff --git a/rl_games/configs/smac/runs/6h_vs_8z_neg.yaml b/rl_games/configs/smac/v1/runs/6h_vs_8z_neg.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/6h_vs_8z_neg.yaml
rename to rl_games/configs/smac/v1/runs/6h_vs_8z_neg.yaml
diff --git a/rl_games/configs/smac/runs/6h_vs_8z_rnn.yaml b/rl_games/configs/smac/v1/runs/6h_vs_8z_rnn.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/6h_vs_8z_rnn.yaml
rename to rl_games/configs/smac/v1/runs/6h_vs_8z_rnn.yaml
diff --git a/rl_games/configs/smac/runs/MMM2.yaml b/rl_games/configs/smac/v1/runs/MMM2.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/MMM2.yaml
rename to rl_games/configs/smac/v1/runs/MMM2.yaml
diff --git a/rl_games/configs/smac/runs/MMM2_conv1d.yaml b/rl_games/configs/smac/v1/runs/MMM2_conv1d.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/MMM2_conv1d.yaml
rename to rl_games/configs/smac/v1/runs/MMM2_conv1d.yaml
diff --git a/rl_games/configs/smac/runs/MMM2_neg.yaml b/rl_games/configs/smac/v1/runs/MMM2_neg.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/MMM2_neg.yaml
rename to rl_games/configs/smac/v1/runs/MMM2_neg.yaml
diff --git a/rl_games/configs/smac/runs/MMM2_rnn.yaml b/rl_games/configs/smac/v1/runs/MMM2_rnn.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/MMM2_rnn.yaml
rename to rl_games/configs/smac/v1/runs/MMM2_rnn.yaml
diff --git a/rl_games/configs/smac/runs/bane_vs_bane.yaml b/rl_games/configs/smac/v1/runs/bane_vs_bane.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/bane_vs_bane.yaml
rename to rl_games/configs/smac/v1/runs/bane_vs_bane.yaml
diff --git a/rl_games/configs/smac/runs/bane_vs_bane_neg.yaml b/rl_games/configs/smac/v1/runs/bane_vs_bane_neg.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/bane_vs_bane_neg.yaml
rename to rl_games/configs/smac/v1/runs/bane_vs_bane_neg.yaml
diff --git a/rl_games/configs/smac/runs/corridor_cv.yaml b/rl_games/configs/smac/v1/runs/corridor_cv.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/corridor_cv.yaml
rename to rl_games/configs/smac/v1/runs/corridor_cv.yaml
diff --git a/rl_games/configs/smac/runs/corridor_cv_neg.yaml b/rl_games/configs/smac/v1/runs/corridor_cv_neg.yaml
similarity index 100%
rename from rl_games/configs/smac/runs/corridor_cv_neg.yaml
rename to rl_games/configs/smac/v1/runs/corridor_cv_neg.yaml
diff --git a/rl_games/configs/smac/v2/5z_torch_cv.yaml b/rl_games/configs/smac/v2/5z_torch_cv.yaml
new file mode 100644
index 00000000..3a29f28e
--- /dev/null
+++ b/rl_games/configs/smac/v2/5z_torch_cv.yaml
@@ -0,0 +1,73 @@
+params:
+  algo:
+    name: a2c_discrete
+
+  model:
+    name: discrete_a2c
+
+  network:
+    name: actor_critic
+    separate: False
+    #normalization: layer_norm
+    space:
+      discrete:
+
+    mlp:
+      units: [256, 128]
+      activation: relu
+      initializer:
+        name: default
+      regularizer:
+        name: 'None'
+
+  config:
+    name: 5z_cv
+    reward_shaper:
+      scale_value: 1
+    normalize_advantage: True
+    gamma: 0.99
+    tau: 0.95
+    learning_rate: 5e-4
+    score_to_win: 20
+    grad_norm: 0.5
+    entropy_coef: 0.001
+    truncate_grads: True
+    env_name: smac
+    e_clip: 0.2
+    clip_value: True
+    num_actors: 8
+    horizon_length: 128
+    minibatch_size: 1536 # 3 * 512
+    mini_epochs: 4
+    critic_coef: 1
+    lr_schedule: None
+    kl_threshold: 0.05
+    normalize_input: True
+    normalize_value: False
+    use_action_masks: True
+    ignore_dead_batches : False
+
+    env_config:
+      name: zerg_5_vs_5
+      frames: 1
+      transpose: False
+      random_invalid_step: False
+      central_value: True
+      reward_only_positive: True
+    central_value_config:
+      minibatch_size: 512
+      mini_epochs: 4
+      learning_rate: 5e-4
+      clip_value: False
+      normalize_input: True
+      network:
+        name: actor_critic
+        central_value: True
+        mlp:
+          units: [256, 128]
+          activation: relu
+          initializer:
+            name: default
+            scale: 2
+          regularizer:
+            name: 'None'
\ No newline at end of file
diff --git a/rl_games/distributed/__init__.py b/rl_games/distributed/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/rl_games/envs/__init__.py b/rl_games/envs/__init__.py
index 69a343a6..6883b34a 100644
--- a/rl_games/envs/__init__.py
+++ b/rl_games/envs/__init__.py
@@ -1,8 +1,6 @@
-from rl_games.envs.connect4_network import ConnectBuilder
 from rl_games.envs.test_network import TestNetBuilder
 from rl_games.algos_torch import model_builder
 
 
 
-model_builder.register_network('connect4net', ConnectBuilder)
 model_builder.register_network('testnet', TestNetBuilder)
\ No newline at end of file
diff --git a/rl_games/envs/connect4_network.py b/rl_games/envs/connect4_network.py
deleted file mode 100644
index 4ef52118..00000000
--- a/rl_games/envs/connect4_network.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import torch
-from torch import nn
-import torch.nn.functional as F
-
-class ConvBlock(nn.Module):
-    def __init__(self):
-        super(ConvBlock, self).__init__()
-        self.action_size = 7
-        self.conv1 = nn.Conv2d(4, 128, 3, stride=1, padding=1)
-        self.bn1 = nn.BatchNorm2d(128)
-
-    def forward(self, s):
-        s = s['obs'].contiguous()
-        #s = s.view(-1, 3, 6, 7) # batch_size x channels x board_x x board_y
-        s = F.relu(self.bn1(self.conv1(s)))
-        return s
-
-
-
-class ResBlock(nn.Module):
-    def __init__(self, inplanes=128, planes=128, stride=1, downsample=None):
-        super(ResBlock, self).__init__()
-        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
-                               padding=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
-                               padding=1, bias=False)
-        self.bn2 = nn.BatchNorm2d(planes)
-
-    def forward(self, x):
-        residual = x
-        out = F.relu(self.bn1(self.conv1(x)))
-        out = self.bn2(self.conv2(out))
-        out += residual
-        out = F.relu(out)
-        return out
-
-
-
-class OutBlock(nn.Module):
-    def __init__(self):
-        super(OutBlock, self).__init__()
-        self.conv = nn.Conv2d(128, 3, kernel_size=1) # value head
-        self.bn = nn.BatchNorm2d(3)
-        self.fc1 = nn.Linear(3*6*7, 32)
-        self.fc2 = nn.Linear(32, 1)
-
-        self.conv1 = nn.Conv2d(128, 32, kernel_size=1) # policy head
-        self.bn1 = nn.BatchNorm2d(32)
-        self.fc = nn.Linear(6*7*32, 7)
-
-    def forward(self,s):
-        v = F.relu(self.bn(self.conv(s))) # value head
-        v = v.view(-1, 3*6*7) # batch_size X channel X height X width
-        v = F.relu(self.fc1(v))
-        v = F.relu(self.fc2(v))
-        v = torch.tanh(v)
-
-        p = F.relu(self.bn1(self.conv1(s))) # policy head
-        p = p.view(-1, 6*7*32)
-        p = self.fc(p)
-        return p, v, None
-
-class ConnectNet(nn.Module):
-    def __init__(self, blocks):
-        super(ConnectNet, self).__init__()
-        self.blocks = blocks
-        self.conv = ConvBlock()
-        for block in range(self.blocks):
-            setattr(self, "res_%i" % block,ResBlock())
-        self.outblock = OutBlock()
-    def is_rnn(self):
-        return False
-
-    def forward(self,s):
-        s = s.permute((0, 3, 1, 2))
-        s = self.conv(s)
-        for block in range(self.blocks):
-            s = getattr(self, "res_%i" % block)(s)
-        s = self.outblock(s)
-        return s
-
-
-from rl_games.algos_torch.network_builder import NetworkBuilder
-
-class ConnectBuilder(NetworkBuilder):
-    def __init__(self, **kwargs):
-        NetworkBuilder.__init__(self)
-
-    def load(self, params):
-        self.params = params
-        self.blocks = params['blocks']
-
-    def build(self, name, **kwargs):
-        return ConnectNet(self.blocks)
-
-    def __call__(self, name, **kwargs):
-        return self.build(name, **kwargs)
-
diff --git a/rl_games/envs/connect4_selfplay.py b/rl_games/envs/connect4_selfplay.py
deleted file mode 100644
index 913e1e2a..00000000
--- a/rl_games/envs/connect4_selfplay.py
+++ /dev/null
@@ -1,132 +0,0 @@
-import gym
-import numpy as np
-from pettingzoo.classic import connect_four_v0
-import yaml
-from rl_games.torch_runner import Runner
-import os
-from collections import deque
-
-
-class ConnectFourSelfPlay(gym.Env):
-    def __init__(self, name="connect_four_v0", **kwargs):
-        gym.Env.__init__(self)
-        self.name = name
-        self.is_deterministic = kwargs.pop('is_deterministic', False)
-        self.is_human = kwargs.pop('is_human', False)
-        self.random_agent = kwargs.pop('random_agent', False)
-        self.config_path = kwargs.pop('config_path')
-        self.agent = None
-
-        self.env = connect_four_v0.env() # gym.make(name, **kwargs)
-        self.action_space = self.env.action_spaces['player_0']
-        observation_space = self.env.observation_spaces['player_0']
-        shp = observation_space.shape
-        self.observation_space = gym.spaces.Box(
-            low=0, high=1, shape=(shp[:-1] + (shp[-1] * 2,)), dtype=np.uint8)
-        self.obs_deque = deque([], maxlen=2)
-        self.agent_id = 0
-
-    def _get_legal_moves(self, agent_id):
-        name = 'player_0' if agent_id == 0 else 'player_1'
-        action_ids = self.env.infos[name]['legal_moves']
-        mask = np.zeros(self.action_space.n, dtype=np.bool)
-        mask[action_ids] = True
-        return mask, action_ids
-
-    def env_step(self, action):
-        obs = self.env.step(action)
-        info = {}
-        name = 'player_0' if self.agent_id == 0 else 'player_1'
-        reward = self.env.rewards[name]
-        done = self.env.dones[name]
-        return obs, reward, done, info
-
-    def get_obs(self):
-        return np.concatenate(self.obs_deque, -1).astype(np.uint8) * 255
-
-    def reset(self):
-        if self.agent == None:
-            self.create_agent(self.config_path)
-
-        self.agent_id = np.random.randint(2)
-        obs = self.env.reset()
-        self.obs_deque.append(obs)
-        self.obs_deque.append(obs)
-        if self.agent_id == 1:
-            op_obs = self.get_obs()
-            op_obs = self.agent.obs_to_torch(op_obs)
-            mask, ids = self._get_legal_moves(0)
-            if self.is_human:
-                self.render()
-                opponent_action = int(input())
-            else:
-                if self.random_agent:
-                    opponent_action = np.random.choice(ids, 1)[0]
-                else:
-                    opponent_action = self.agent.get_masked_action(
-                        op_obs, mask, self.is_deterministic).item()
-
-            obs, _, _, _ = self.env_step(opponent_action)
-
-        self.obs_deque.append(obs)
-        return self.get_obs()
-
-    def create_agent(self, config):
-        with open(config, 'r') as stream:
-            config = yaml.safe_load(stream)
-            runner = Runner()
-            runner.load(config)
-        config = runner.get_prebuilt_config()
-        # 'RAYLIB has bug here, CUDA_VISIBLE_DEVICES become unset'
-        if 'CUDA_VISIBLE_DEVICES' in os.environ:
-            os.environ.pop('CUDA_VISIBLE_DEVICES')
-
-        self.agent = runner.create_player()
-        self.agent.model.eval()
-
-    def step(self, action):
-
-        obs, reward, done, info = self.env_step(action)
-        self.obs_deque.append(obs)
-
-        if done:
-            if reward == 1:
-                info['battle_won'] = 1
-            else:
-                info['battle_won'] = 0
-            return self.get_obs(), reward, done, info
-
-        op_obs = self.get_obs()
-
-        op_obs = self.agent.obs_to_torch(op_obs)
-        mask, ids = self._get_legal_moves(1-self.agent_id)
-        if self.is_human:
-            self.render()
-            opponent_action = int(input())
-        else:
-            if self.random_agent:
-                opponent_action = np.random.choice(ids, 1)[0]
-            else:
-                opponent_action = self.agent.get_masked_action(
-                    op_obs, mask, self.is_deterministic).item()
-        obs, reward, done, _ = self.env_step(opponent_action)
-        if done:
-            if reward == -1:
-                info['battle_won'] = 0
-            else:
-                info['battle_won'] = 1
-        self.obs_deque.append(obs)
-        return self.get_obs(), reward, done, info
-
-    def render(self, mode='ansi'):
-        self.env.render(mode)
-
-    def update_weights(self, weigths):
-        self.agent.set_weights(weigths)
-
-    def get_action_mask(self):
-        mask, _ = self._get_legal_moves(self.agent_id)
-        return mask
-
-    def has_action_mask(self):
-        return True
diff --git a/rl_games/envs/multiwalker.py b/rl_games/envs/multiwalker.py
index de3ab71a..f421bd61 100644
--- a/rl_games/envs/multiwalker.py
+++ b/rl_games/envs/multiwalker.py
@@ -5,7 +5,6 @@
 from rl_games.torch_runner import Runner
 import os
 from collections import deque
-import rl_games.envs.connect4_network
 
 class MultiWalker(gym.Env):
     def __init__(self, name="multiwalker", **kwargs):
diff --git a/rl_games/envs/smac_env.py b/rl_games/envs/smac_env.py
index a4039e55..79695743 100644
--- a/rl_games/envs/smac_env.py
+++ b/rl_games/envs/smac_env.py
@@ -96,7 +96,7 @@ def step(self, actions):
         return obses, rewards, dones, info
 
     def get_action_mask(self):
-        return np.array(self.env.get_avail_actions(), dtype=np.bool)
+        return np.array(self.env.get_avail_actions(), dtype=bool)
 
     def has_action_mask(self):
         return not self.random_invalid_step