diff --git a/rl_games/common/env_configurations.py b/rl_games/common/env_configurations.py index 01102be2..e553deaa 100644 --- a/rl_games/common/env_configurations.py +++ b/rl_games/common/env_configurations.py @@ -86,16 +86,6 @@ def create_slime_gym_env(**kwargs): env = gym.make(name, **kwargs) return env -def create_connect_four_env(**kwargs): - from rl_games.envs.connect4_selfplay import ConnectFourSelfPlay - name = kwargs.pop('name') - limit_steps = kwargs.pop('limit_steps', False) - self_play = kwargs.pop('self_play', False) - if self_play: - env = ConnectFourSelfPlay(name, **kwargs) - else: - env = gym.make(name, **kwargs) - return env def create_atari_gym_env(**kwargs): #frames = kwargs.pop('frames', 1) @@ -171,6 +161,21 @@ def create_smac(name, **kwargs): env = SMACEnv(name, **kwargs) + if frames > 1: + if has_cv: + env = wrappers.BatchedFrameStackWithStates(env, frames, transpose=False, flatten=flatten) + else: + env = wrappers.BatchedFrameStack(env, frames, transpose=False, flatten=flatten) + return env + +def create_smac_v2(name, **kwargs): + from rl_games.envs.smac_v2_env import SMACEnvV2 + frames = kwargs.pop('frames', 1) + transpose = kwargs.pop('transpose', False) + flatten = kwargs.pop('flatten', True) + has_cv = kwargs.get('central_value', False) + env = SMACEnvV2(name, **kwargs) + if frames > 1: if has_cv: env = wrappers.BatchedFrameStackWithStates(env, frames, transpose=False, flatten=flatten) @@ -359,6 +364,10 @@ def create_env(name, **kwargs): 'env_creator' : lambda **kwargs : create_smac(**kwargs), 'vecenv_type' : 'RAY' }, + 'smac_v2' : { + 'env_creator' : lambda **kwargs : create_smac_v2(**kwargs), + 'vecenv_type' : 'RAY' + }, 'smac_cnn' : { 'env_creator' : lambda **kwargs : create_smac_cnn(**kwargs), 'vecenv_type' : 'RAY' @@ -391,10 +400,6 @@ def create_env(name, **kwargs): 'env_creator' : lambda **kwargs : create_minigrid_env(kwargs.pop('name'), **kwargs), 'vecenv_type' : 'RAY' }, - 'connect4_env' : { - 'env_creator' : lambda **kwargs : create_connect_four_env(**kwargs), - 'vecenv_type' : 'RAY' - }, 'multiwalker_env' : { 'env_creator' : lambda **kwargs : create_multiwalker_env(**kwargs), 'vecenv_type' : 'RAY' diff --git a/rl_games/common/experience.py b/rl_games/common/experience.py index 7b895eef..9cc880a6 100644 --- a/rl_games/common/experience.py +++ b/rl_games/common/experience.py @@ -20,7 +20,7 @@ def __init__(self, size, ob_space): self._next_obses = np.zeros((size,) + ob_space.shape, dtype=ob_space.dtype) self._rewards = np.zeros(size) self._actions = np.zeros(size, dtype=np.int32) - self._dones = np.zeros(size, dtype=np.bool) + self._dones = np.zeros(size, dtype=bool) self._maxsize = size self._next_idx = 0 @@ -341,7 +341,7 @@ def _init_from_env_info(self, env_info): if self.is_discrete or self.is_multi_discrete: self.tensor_dict['actions'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=int), obs_base_shape) if self.use_action_masks: - self.tensor_dict['action_masks'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape + (np.sum(self.actions_num),), dtype=np.bool), obs_base_shape) + self.tensor_dict['action_masks'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape + (np.sum(self.actions_num),), dtype=bool), obs_base_shape) if self.is_continuous: self.tensor_dict['actions'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=np.float32), obs_base_shape) self.tensor_dict['mus'] = 
self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=np.float32), obs_base_shape) diff --git a/rl_games/common/vecenv.py b/rl_games/common/vecenv.py index 646da555..01016723 100644 --- a/rl_games/common/vecenv.py +++ b/rl_games/common/vecenv.py @@ -7,7 +7,6 @@ from time import sleep import torch - class RayWorker: def __init__(self, config_name, config): self.env = configurations[config_name]['env_creator'](**config) @@ -96,14 +95,16 @@ def get_env_info(self): class RayVecEnv(IVecEnv): + import ray + def __init__(self, config_name, num_actors, **kwargs): self.config_name = config_name self.num_actors = num_actors self.use_torch = False self.seed = kwargs.pop('seed', None) - import ray - self.remote_worker = ray.remote(RayWorker) + + self.remote_worker = self.ray.remote(RayWorker) self.workers = [self.remote_worker.remote(self.config_name, kwargs) for i in range(self.num_actors)] if self.seed is not None: @@ -111,15 +112,15 @@ def __init__(self, config_name, num_actors, **kwargs): seed_set = [] for (seed, worker) in zip(seeds, self.workers): seed_set.append(worker.seed.remote(seed)) - ray.get(seed_set) + self.ray.get(seed_set) res = self.workers[0].get_number_of_agents.remote() - self.num_agents = ray.get(res) + self.num_agents = self.ray.get(res) res = self.workers[0].get_env_info.remote() - env_info = ray.get(res) + env_info = self.ray.get(res) res = self.workers[0].can_concat_infos.remote() - can_concat_infos = ray.get(res) + can_concat_infos = self.ray.get(res) self.use_global_obs = env_info['use_global_observations'] self.concat_infos = can_concat_infos self.obs_type_dict = type(env_info.get('observation_space')) is gym.spaces.Dict @@ -139,7 +140,7 @@ def step(self, actions): for num, worker in enumerate(self.workers): res_obs.append(worker.step.remote(actions[self.num_agents * num: self.num_agents * num + self.num_agents])) - all_res = ray.get(res_obs) + all_res = self.ray.get(res_obs) for res in all_res: cobs, crewards, cdones, cinfos = res if self.use_global_obs: @@ -171,27 +172,27 @@ def step(self, actions): def get_env_info(self): res = self.workers[0].get_env_info.remote() - return ray.get(res) + return self.ray.get(res) def set_weights(self, indices, weights): res = [] for ind in indices: res.append(self.workers[ind].set_weights.remote(weights)) - ray.get(res) + self.ray.get(res) def has_action_masks(self): return True def get_action_masks(self): mask = [worker.get_action_mask.remote() for worker in self.workers] - masks = ray.get(mask) + masks = self.ray.get(mask) return np.concatenate(masks, axis=0) def reset(self): res_obs = [worker.reset.remote() for worker in self.workers] newobs, newstates = [],[] for res in res_obs: - cobs = ray.get(res) + cobs = self.ray.get(res) if self.use_global_obs: newobs.append(cobs["obs"]) newstates.append(cobs["state"]) diff --git a/rl_games/configs/smac/10m_vs_11m_torch.yaml b/rl_games/configs/smac/v1/10m_vs_11m_torch.yaml similarity index 100% rename from rl_games/configs/smac/10m_vs_11m_torch.yaml rename to rl_games/configs/smac/v1/10m_vs_11m_torch.yaml diff --git a/rl_games/configs/smac/27m_vs_30m_cv.yaml b/rl_games/configs/smac/v1/27m_vs_30m_cv.yaml similarity index 100% rename from rl_games/configs/smac/27m_vs_30m_cv.yaml rename to rl_games/configs/smac/v1/27m_vs_30m_cv.yaml diff --git a/rl_games/configs/smac/27m_vs_30m_torch.yaml b/rl_games/configs/smac/v1/27m_vs_30m_torch.yaml similarity index 100% rename from rl_games/configs/smac/27m_vs_30m_torch.yaml rename to 
rl_games/configs/smac/v1/27m_vs_30m_torch.yaml diff --git a/rl_games/configs/smac/2m_vs_1z.yaml b/rl_games/configs/smac/v1/2m_vs_1z.yaml similarity index 100% rename from rl_games/configs/smac/2m_vs_1z.yaml rename to rl_games/configs/smac/v1/2m_vs_1z.yaml diff --git a/rl_games/configs/smac/2m_vs_1z_torch.yaml b/rl_games/configs/smac/v1/2m_vs_1z_torch.yaml similarity index 100% rename from rl_games/configs/smac/2m_vs_1z_torch.yaml rename to rl_games/configs/smac/v1/2m_vs_1z_torch.yaml diff --git a/rl_games/configs/smac/2s_vs_1c.yaml b/rl_games/configs/smac/v1/2s_vs_1c.yaml similarity index 100% rename from rl_games/configs/smac/2s_vs_1c.yaml rename to rl_games/configs/smac/v1/2s_vs_1c.yaml diff --git a/rl_games/configs/smac/3m_cnn_torch.yaml b/rl_games/configs/smac/v1/3m_cnn_torch.yaml similarity index 100% rename from rl_games/configs/smac/3m_cnn_torch.yaml rename to rl_games/configs/smac/v1/3m_cnn_torch.yaml diff --git a/rl_games/configs/smac/3m_torch.yaml b/rl_games/configs/smac/v1/3m_torch.yaml similarity index 100% rename from rl_games/configs/smac/3m_torch.yaml rename to rl_games/configs/smac/v1/3m_torch.yaml diff --git a/rl_games/configs/smac/3m_torch_cv.yaml b/rl_games/configs/smac/v1/3m_torch_cv.yaml similarity index 100% rename from rl_games/configs/smac/3m_torch_cv.yaml rename to rl_games/configs/smac/v1/3m_torch_cv.yaml diff --git a/rl_games/configs/smac/3m_torch_cv_joint.yaml b/rl_games/configs/smac/v1/3m_torch_cv_joint.yaml similarity index 100% rename from rl_games/configs/smac/3m_torch_cv_joint.yaml rename to rl_games/configs/smac/v1/3m_torch_cv_joint.yaml diff --git a/rl_games/configs/smac/3m_torch_cv_rnn.yaml b/rl_games/configs/smac/v1/3m_torch_cv_rnn.yaml similarity index 100% rename from rl_games/configs/smac/3m_torch_cv_rnn.yaml rename to rl_games/configs/smac/v1/3m_torch_cv_rnn.yaml diff --git a/rl_games/configs/smac/3m_torch_rnn.yaml b/rl_games/configs/smac/v1/3m_torch_rnn.yaml similarity index 100% rename from rl_games/configs/smac/3m_torch_rnn.yaml rename to rl_games/configs/smac/v1/3m_torch_rnn.yaml diff --git a/rl_games/configs/smac/3m_torch_sparse.yaml b/rl_games/configs/smac/v1/3m_torch_sparse.yaml similarity index 100% rename from rl_games/configs/smac/3m_torch_sparse.yaml rename to rl_games/configs/smac/v1/3m_torch_sparse.yaml diff --git a/rl_games/configs/smac/3s5z_vs_3s6z_torch.yaml b/rl_games/configs/smac/v1/3s5z_vs_3s6z_torch.yaml similarity index 100% rename from rl_games/configs/smac/3s5z_vs_3s6z_torch.yaml rename to rl_games/configs/smac/v1/3s5z_vs_3s6z_torch.yaml diff --git a/rl_games/configs/smac/3s5z_vs_3s6z_torch_cv.yaml b/rl_games/configs/smac/v1/3s5z_vs_3s6z_torch_cv.yaml similarity index 100% rename from rl_games/configs/smac/3s5z_vs_3s6z_torch_cv.yaml rename to rl_games/configs/smac/v1/3s5z_vs_3s6z_torch_cv.yaml diff --git a/rl_games/configs/smac/3s_vs_4z.yaml b/rl_games/configs/smac/v1/3s_vs_4z.yaml similarity index 100% rename from rl_games/configs/smac/3s_vs_4z.yaml rename to rl_games/configs/smac/v1/3s_vs_4z.yaml diff --git a/rl_games/configs/smac/3s_vs_5z.yaml b/rl_games/configs/smac/v1/3s_vs_5z.yaml similarity index 100% rename from rl_games/configs/smac/3s_vs_5z.yaml rename to rl_games/configs/smac/v1/3s_vs_5z.yaml diff --git a/rl_games/configs/smac/3s_vs_5z_cv.yaml b/rl_games/configs/smac/v1/3s_vs_5z_cv.yaml similarity index 100% rename from rl_games/configs/smac/3s_vs_5z_cv.yaml rename to rl_games/configs/smac/v1/3s_vs_5z_cv.yaml diff --git a/rl_games/configs/smac/3s_vs_5z_cv_rnn.yaml b/rl_games/configs/smac/v1/3s_vs_5z_cv_rnn.yaml 
similarity index 100% rename from rl_games/configs/smac/3s_vs_5z_cv_rnn.yaml rename to rl_games/configs/smac/v1/3s_vs_5z_cv_rnn.yaml diff --git a/rl_games/configs/smac/3s_vs_5z_torch_lstm.yaml b/rl_games/configs/smac/v1/3s_vs_5z_torch_lstm.yaml similarity index 100% rename from rl_games/configs/smac/3s_vs_5z_torch_lstm.yaml rename to rl_games/configs/smac/v1/3s_vs_5z_torch_lstm.yaml diff --git a/rl_games/configs/smac/3s_vs_5z_torch_lstm2.yaml b/rl_games/configs/smac/v1/3s_vs_5z_torch_lstm2.yaml similarity index 100% rename from rl_games/configs/smac/3s_vs_5z_torch_lstm2.yaml rename to rl_games/configs/smac/v1/3s_vs_5z_torch_lstm2.yaml diff --git a/rl_games/configs/smac/5m_vs_6m_rnn.yaml b/rl_games/configs/smac/v1/5m_vs_6m_rnn.yaml similarity index 100% rename from rl_games/configs/smac/5m_vs_6m_rnn.yaml rename to rl_games/configs/smac/v1/5m_vs_6m_rnn.yaml diff --git a/rl_games/configs/smac/5m_vs_6m_rnn_cv.yaml b/rl_games/configs/smac/v1/5m_vs_6m_rnn_cv.yaml similarity index 100% rename from rl_games/configs/smac/5m_vs_6m_rnn_cv.yaml rename to rl_games/configs/smac/v1/5m_vs_6m_rnn_cv.yaml diff --git a/rl_games/configs/smac/5m_vs_6m_torch.yaml b/rl_games/configs/smac/v1/5m_vs_6m_torch.yaml similarity index 100% rename from rl_games/configs/smac/5m_vs_6m_torch.yaml rename to rl_games/configs/smac/v1/5m_vs_6m_torch.yaml diff --git a/rl_games/configs/smac/6h_vs_8z_torch.yaml b/rl_games/configs/smac/v1/6h_vs_8z_torch.yaml similarity index 100% rename from rl_games/configs/smac/6h_vs_8z_torch.yaml rename to rl_games/configs/smac/v1/6h_vs_8z_torch.yaml diff --git a/rl_games/configs/smac/6h_vs_8z_torch_cv.yaml b/rl_games/configs/smac/v1/6h_vs_8z_torch_cv.yaml similarity index 100% rename from rl_games/configs/smac/6h_vs_8z_torch_cv.yaml rename to rl_games/configs/smac/v1/6h_vs_8z_torch_cv.yaml diff --git a/rl_games/configs/smac/8m_torch.yaml b/rl_games/configs/smac/v1/8m_torch.yaml similarity index 100% rename from rl_games/configs/smac/8m_torch.yaml rename to rl_games/configs/smac/v1/8m_torch.yaml diff --git a/rl_games/configs/smac/8m_torch_cv.yaml b/rl_games/configs/smac/v1/8m_torch_cv.yaml similarity index 100% rename from rl_games/configs/smac/8m_torch_cv.yaml rename to rl_games/configs/smac/v1/8m_torch_cv.yaml diff --git a/rl_games/configs/smac/MMM2_torch.yaml b/rl_games/configs/smac/v1/MMM2_torch.yaml similarity index 100% rename from rl_games/configs/smac/MMM2_torch.yaml rename to rl_games/configs/smac/v1/MMM2_torch.yaml diff --git a/rl_games/configs/smac/corridor_torch.yaml b/rl_games/configs/smac/v1/corridor_torch.yaml similarity index 100% rename from rl_games/configs/smac/corridor_torch.yaml rename to rl_games/configs/smac/v1/corridor_torch.yaml diff --git a/rl_games/configs/smac/corridor_torch_cv.yaml b/rl_games/configs/smac/v1/corridor_torch_cv.yaml similarity index 100% rename from rl_games/configs/smac/corridor_torch_cv.yaml rename to rl_games/configs/smac/v1/corridor_torch_cv.yaml diff --git a/rl_games/configs/smac/runs/2c_vs_64zg.yaml b/rl_games/configs/smac/v1/runs/2c_vs_64zg.yaml similarity index 100% rename from rl_games/configs/smac/runs/2c_vs_64zg.yaml rename to rl_games/configs/smac/v1/runs/2c_vs_64zg.yaml diff --git a/rl_games/configs/smac/runs/2c_vs_64zg_neg.yaml b/rl_games/configs/smac/v1/runs/2c_vs_64zg_neg.yaml similarity index 100% rename from rl_games/configs/smac/runs/2c_vs_64zg_neg.yaml rename to rl_games/configs/smac/v1/runs/2c_vs_64zg_neg.yaml diff --git a/rl_games/configs/smac/runs/2s3z.yaml b/rl_games/configs/smac/v1/runs/2s3z.yaml similarity index 100% rename 
from rl_games/configs/smac/runs/2s3z.yaml rename to rl_games/configs/smac/v1/runs/2s3z.yaml diff --git a/rl_games/configs/smac/runs/2s3z_neg.yaml b/rl_games/configs/smac/v1/runs/2s3z_neg.yaml similarity index 100% rename from rl_games/configs/smac/runs/2s3z_neg.yaml rename to rl_games/configs/smac/v1/runs/2s3z_neg.yaml diff --git a/rl_games/configs/smac/runs/2s_vs_1c.yaml b/rl_games/configs/smac/v1/runs/2s_vs_1c.yaml similarity index 100% rename from rl_games/configs/smac/runs/2s_vs_1c.yaml rename to rl_games/configs/smac/v1/runs/2s_vs_1c.yaml diff --git a/rl_games/configs/smac/runs/2s_vs_1c_neg.yaml b/rl_games/configs/smac/v1/runs/2s_vs_1c_neg.yaml similarity index 100% rename from rl_games/configs/smac/runs/2s_vs_1c_neg.yaml rename to rl_games/configs/smac/v1/runs/2s_vs_1c_neg.yaml diff --git a/rl_games/configs/smac/runs/3s5z.yaml b/rl_games/configs/smac/v1/runs/3s5z.yaml similarity index 100% rename from rl_games/configs/smac/runs/3s5z.yaml rename to rl_games/configs/smac/v1/runs/3s5z.yaml diff --git a/rl_games/configs/smac/runs/3s5z_neg.yaml b/rl_games/configs/smac/v1/runs/3s5z_neg.yaml similarity index 100% rename from rl_games/configs/smac/runs/3s5z_neg.yaml rename to rl_games/configs/smac/v1/runs/3s5z_neg.yaml diff --git a/rl_games/configs/smac/runs/3s_vs_5z.yaml b/rl_games/configs/smac/v1/runs/3s_vs_5z.yaml similarity index 100% rename from rl_games/configs/smac/runs/3s_vs_5z.yaml rename to rl_games/configs/smac/v1/runs/3s_vs_5z.yaml diff --git a/rl_games/configs/smac/runs/3s_vs_5z_neg.yaml b/rl_games/configs/smac/v1/runs/3s_vs_5z_neg.yaml similarity index 100% rename from rl_games/configs/smac/runs/3s_vs_5z_neg.yaml rename to rl_games/configs/smac/v1/runs/3s_vs_5z_neg.yaml diff --git a/rl_games/configs/smac/runs/3s_vs_5z_neg_joint.yaml b/rl_games/configs/smac/v1/runs/3s_vs_5z_neg_joint.yaml similarity index 100% rename from rl_games/configs/smac/runs/3s_vs_5z_neg_joint.yaml rename to rl_games/configs/smac/v1/runs/3s_vs_5z_neg_joint.yaml diff --git a/rl_games/configs/smac/runs/6h_vs_8z.yaml b/rl_games/configs/smac/v1/runs/6h_vs_8z.yaml similarity index 100% rename from rl_games/configs/smac/runs/6h_vs_8z.yaml rename to rl_games/configs/smac/v1/runs/6h_vs_8z.yaml diff --git a/rl_games/configs/smac/runs/6h_vs_8z_neg.yaml b/rl_games/configs/smac/v1/runs/6h_vs_8z_neg.yaml similarity index 100% rename from rl_games/configs/smac/runs/6h_vs_8z_neg.yaml rename to rl_games/configs/smac/v1/runs/6h_vs_8z_neg.yaml diff --git a/rl_games/configs/smac/runs/6h_vs_8z_rnn.yaml b/rl_games/configs/smac/v1/runs/6h_vs_8z_rnn.yaml similarity index 100% rename from rl_games/configs/smac/runs/6h_vs_8z_rnn.yaml rename to rl_games/configs/smac/v1/runs/6h_vs_8z_rnn.yaml diff --git a/rl_games/configs/smac/runs/MMM2.yaml b/rl_games/configs/smac/v1/runs/MMM2.yaml similarity index 100% rename from rl_games/configs/smac/runs/MMM2.yaml rename to rl_games/configs/smac/v1/runs/MMM2.yaml diff --git a/rl_games/configs/smac/runs/MMM2_conv1d.yaml b/rl_games/configs/smac/v1/runs/MMM2_conv1d.yaml similarity index 100% rename from rl_games/configs/smac/runs/MMM2_conv1d.yaml rename to rl_games/configs/smac/v1/runs/MMM2_conv1d.yaml diff --git a/rl_games/configs/smac/runs/MMM2_neg.yaml b/rl_games/configs/smac/v1/runs/MMM2_neg.yaml similarity index 100% rename from rl_games/configs/smac/runs/MMM2_neg.yaml rename to rl_games/configs/smac/v1/runs/MMM2_neg.yaml diff --git a/rl_games/configs/smac/runs/MMM2_rnn.yaml b/rl_games/configs/smac/v1/runs/MMM2_rnn.yaml similarity index 100% rename from 
rl_games/configs/smac/runs/MMM2_rnn.yaml rename to rl_games/configs/smac/v1/runs/MMM2_rnn.yaml diff --git a/rl_games/configs/smac/runs/bane_vs_bane.yaml b/rl_games/configs/smac/v1/runs/bane_vs_bane.yaml similarity index 100% rename from rl_games/configs/smac/runs/bane_vs_bane.yaml rename to rl_games/configs/smac/v1/runs/bane_vs_bane.yaml diff --git a/rl_games/configs/smac/runs/bane_vs_bane_neg.yaml b/rl_games/configs/smac/v1/runs/bane_vs_bane_neg.yaml similarity index 100% rename from rl_games/configs/smac/runs/bane_vs_bane_neg.yaml rename to rl_games/configs/smac/v1/runs/bane_vs_bane_neg.yaml diff --git a/rl_games/configs/smac/runs/corridor_cv.yaml b/rl_games/configs/smac/v1/runs/corridor_cv.yaml similarity index 100% rename from rl_games/configs/smac/runs/corridor_cv.yaml rename to rl_games/configs/smac/v1/runs/corridor_cv.yaml diff --git a/rl_games/configs/smac/runs/corridor_cv_neg.yaml b/rl_games/configs/smac/v1/runs/corridor_cv_neg.yaml similarity index 100% rename from rl_games/configs/smac/runs/corridor_cv_neg.yaml rename to rl_games/configs/smac/v1/runs/corridor_cv_neg.yaml diff --git a/rl_games/configs/smac/v2/env_configs/sc2_gen_protoss.yaml b/rl_games/configs/smac/v2/env_configs/sc2_gen_protoss.yaml new file mode 100644 index 00000000..ddfc7b21 --- /dev/null +++ b/rl_games/configs/smac/v2/env_configs/sc2_gen_protoss.yaml @@ -0,0 +1,69 @@ +env: sc2wrapped + +env_args: + continuing_episode: False + difficulty: "7" + game_version: null + map_name: "10gen_protoss" + move_amount: 2 + obs_all_health: True + obs_instead_of_state: False + obs_last_action: False + obs_own_health: True + obs_pathing_grid: False + obs_terrain_height: False + obs_timestep_number: False + reward_death_value: 10 + reward_defeat: 0 + reward_negative_scale: 0.5 + reward_only_positive: True + reward_scale: True + reward_scale_rate: 20 + reward_sparse: False + reward_win: 200 + replay_dir: "" + replay_prefix: "" + conic_fov: False + use_unit_ranges: True + min_attack_range: 2 + obs_own_pos: True + num_fov_actions: 12 + capability_config: + n_units: 5 + n_enemies: 5 + team_gen: + dist_type: "weighted_teams" + unit_types: + - "stalker" + - "zealot" + - "colossus" + weights: + - 0.45 + - 0.45 + - 0.1 + observe: True + start_positions: + dist_type: "surrounded_and_reflect" + p: 0.5 + map_x: 32 + map_y: 32 + + # enemy_mask: + # dist_type: "mask" + # mask_probability: 0.5 + # n_enemies: 5 + state_last_action: True + state_timestep_number: False + step_mul: 8 + heuristic_ai: False + # heuristic_rest: False + debug: False + prob_obs_enemy: 1.0 + action_mask: True + +test_nepisode: 32 +test_interval: 10000 +log_interval: 2000 +runner_log_interval: 2000 +learner_log_interval: 2000 +t_max: 10050000 diff --git a/rl_games/configs/smac/v2/env_configs/sc2_gen_protoss_epo.yaml b/rl_games/configs/smac/v2/env_configs/sc2_gen_protoss_epo.yaml new file mode 100644 index 00000000..be958172 --- /dev/null +++ b/rl_games/configs/smac/v2/env_configs/sc2_gen_protoss_epo.yaml @@ -0,0 +1,70 @@ +env: sc2wrapped + +env_args: + continuing_episode: False + difficulty: "7" + game_version: null + map_name: "10gen_protoss" + move_amount: 2 + obs_all_health: True + obs_instead_of_state: False + obs_last_action: False + obs_own_health: True + obs_pathing_grid: False + obs_terrain_height: False + obs_timestep_number: False + reward_death_value: 10 + reward_defeat: 0 + reward_negative_scale: 0.5 + reward_only_positive: True + reward_scale: True + reward_scale_rate: 20 + reward_sparse: False + reward_win: 200 + replay_dir: "" + replay_prefix: "" + 
conic_fov: False + use_unit_ranges: True + min_attack_range: 2 + obs_own_pos: True + num_fov_actions: 12 + capability_config: + n_units: 5 + n_enemies: 5 + team_gen: + dist_type: "weighted_teams" + unit_types: + - "stalker" + - "zealot" + - "colossus" + weights: + - 0.45 + - 0.45 + - 0.1 + observe: True + start_positions: + dist_type: "surrounded_and_reflect" + p: 0.5 + map_x: 32 + map_y: 32 + + # enemy_mask: + # dist_type: "mask" + # mask_probability: 0.5 + # n_enemies: 5 + state_last_action: True + state_timestep_number: False + step_mul: 8 + heuristic_ai: False + # heuristic_rest: False + debug: False + # Most severe partial obs setting: + prob_obs_enemy: 0.0 + action_mask: False + +test_nepisode: 32 +test_interval: 10000 +log_interval: 2000 +runner_log_interval: 2000 +learner_log_interval: 2000 +t_max: 10050000 \ No newline at end of file diff --git a/rl_games/configs/smac/v2/env_configs/sc2_gen_terran.yaml b/rl_games/configs/smac/v2/env_configs/sc2_gen_terran.yaml new file mode 100644 index 00000000..50ecb69c --- /dev/null +++ b/rl_games/configs/smac/v2/env_configs/sc2_gen_terran.yaml @@ -0,0 +1,71 @@ +env: sc2wrapped + +env_args: + continuing_episode: False + difficulty: "7" + game_version: null + map_name: "10gen_terran" + move_amount: 2 + obs_all_health: True + obs_instead_of_state: False + obs_last_action: False + obs_own_health: True + obs_pathing_grid: False + obs_terrain_height: False + obs_timestep_number: False + reward_death_value: 10 + reward_defeat: 0 + reward_negative_scale: 0.5 + reward_only_positive: True + reward_scale: True + reward_scale_rate: 20 + reward_sparse: False + reward_win: 200 + replay_dir: "" + replay_prefix: "" + conic_fov: False + obs_own_pos: True + use_unit_ranges: True + min_attack_range: 2 + num_fov_actions: 12 + capability_config: + n_units: 5 + n_enemies: 5 + team_gen: + dist_type: "weighted_teams" + unit_types: + - "marine" + - "marauder" + - "medivac" + weights: + - 0.45 + - 0.45 + - 0.1 + exception_unit_types: + - "medivac" + observe: True + + start_positions: + dist_type: "surrounded_and_reflect" + p: 0.5 + map_x: 32 + map_y: 32 + # enemy_mask: + # dist_type: "mask" + # mask_probability: 0.5 + # n_enemies: 5 + state_last_action: True + state_timestep_number: False + step_mul: 8 + heuristic_ai: False + # heuristic_rest: False + debug: False + prob_obs_enemy: 1.0 + action_mask: True + +test_nepisode: 32 +test_interval: 10000 +log_interval: 2000 +runner_log_interval: 2000 +learner_log_interval: 2000 +t_max: 10050000 diff --git a/rl_games/configs/smac/v2/env_configs/sc2_gen_terran_epo.yaml b/rl_games/configs/smac/v2/env_configs/sc2_gen_terran_epo.yaml new file mode 100644 index 00000000..02bca388 --- /dev/null +++ b/rl_games/configs/smac/v2/env_configs/sc2_gen_terran_epo.yaml @@ -0,0 +1,72 @@ +env: sc2wrapped + +env_args: + continuing_episode: False + difficulty: "7" + game_version: null + map_name: "10gen_terran" + move_amount: 2 + obs_all_health: True + obs_instead_of_state: False + obs_last_action: False + obs_own_health: True + obs_pathing_grid: False + obs_terrain_height: False + obs_timestep_number: False + reward_death_value: 10 + reward_defeat: 0 + reward_negative_scale: 0.5 + reward_only_positive: True + reward_scale: True + reward_scale_rate: 20 + reward_sparse: False + reward_win: 200 + replay_dir: "" + replay_prefix: "" + conic_fov: False + obs_own_pos: True + use_unit_ranges: True + min_attack_range: 2 + num_fov_actions: 12 + capability_config: + n_units: 5 + n_enemies: 5 + team_gen: + dist_type: "weighted_teams" + unit_types: + - 
"marine" + - "marauder" + - "medivac" + weights: + - 0.45 + - 0.45 + - 0.1 + exception_unit_types: + - "medivac" + observe: True + + start_positions: + dist_type: "surrounded_and_reflect" + p: 0.5 + map_x: 32 + map_y: 32 + # enemy_mask: + # dist_type: "mask" + # mask_probability: 0.5 + # n_enemies: 5 + state_last_action: True + state_timestep_number: False + step_mul: 8 + heuristic_ai: False + # heuristic_rest: False + debug: False + # Most severe partial obs setting: + prob_obs_enemy: 0.0 + action_mask: False + +test_nepisode: 32 +test_interval: 10000 +log_interval: 2000 +runner_log_interval: 2000 +learner_log_interval: 2000 +t_max: 10050000 diff --git a/rl_games/configs/smac/v2/env_configs/sc2_gen_zerg.yaml b/rl_games/configs/smac/v2/env_configs/sc2_gen_zerg.yaml new file mode 100644 index 00000000..f13c0707 --- /dev/null +++ b/rl_games/configs/smac/v2/env_configs/sc2_gen_zerg.yaml @@ -0,0 +1,71 @@ +env: sc2wrapped + +env_args: + continuing_episode: False + difficulty: "7" + game_version: null + map_name: "10gen_zerg" + move_amount: 2 + obs_all_health: True + obs_instead_of_state: False + obs_last_action: False + obs_own_health: True + obs_pathing_grid: False + obs_terrain_height: False + obs_timestep_number: False + reward_death_value: 10 + reward_defeat: 0 + reward_negative_scale: 0.5 + reward_only_positive: True + reward_scale: True + reward_scale_rate: 20 + reward_sparse: False + reward_win: 200 + replay_dir: "" + replay_prefix: "" + conic_fov: False + use_unit_ranges: True + min_attack_range: 2 + num_fov_actions: 12 + obs_own_pos: True + capability_config: + n_units: 5 + n_enemies: 5 + team_gen: + dist_type: "weighted_teams" + unit_types: + - "zergling" + - "baneling" + - "hydralisk" + weights: + - 0.45 + - 0.1 + - 0.45 + exception_unit_types: + - "baneling" + observe: True + + start_positions: + dist_type: "surrounded_and_reflect" + p: 0.5 + map_x: 32 + map_y: 32 + # enemy_mask: + # dist_type: "mask" + # mask_probability: 0.5 + # n_enemies: 5 + state_last_action: True + state_timestep_number: False + step_mul: 8 + heuristic_ai: False + # heuristic_rest: False + debug: False + prob_obs_enemy: 1.0 + action_mask: True + +test_nepisode: 32 +test_interval: 10000 +log_interval: 2000 +runner_log_interval: 2000 +learner_log_interval: 2000 +t_max: 10050000 diff --git a/rl_games/configs/smac/v2/env_configs/sc2_gen_zerg_epo.yaml b/rl_games/configs/smac/v2/env_configs/sc2_gen_zerg_epo.yaml new file mode 100644 index 00000000..d75034f5 --- /dev/null +++ b/rl_games/configs/smac/v2/env_configs/sc2_gen_zerg_epo.yaml @@ -0,0 +1,72 @@ +env: sc2wrapped + +env_args: + continuing_episode: False + difficulty: "7" + game_version: null + map_name: "10gen_zerg" + move_amount: 2 + obs_all_health: True + obs_instead_of_state: False + obs_last_action: False + obs_own_health: True + obs_pathing_grid: False + obs_terrain_height: False + obs_timestep_number: False + reward_death_value: 10 + reward_defeat: 0 + reward_negative_scale: 0.5 + reward_only_positive: True + reward_scale: True + reward_scale_rate: 20 + reward_sparse: False + reward_win: 200 + replay_dir: "" + replay_prefix: "" + conic_fov: False + use_unit_ranges: True + min_attack_range: 2 + num_fov_actions: 12 + obs_own_pos: True + capability_config: + n_units: 5 + n_enemies: 5 + team_gen: + dist_type: "weighted_teams" + unit_types: + - "zergling" + - "baneling" + - "hydralisk" + weights: + - 0.45 + - 0.1 + - 0.45 + exception_unit_types: + - "baneling" + observe: True + + start_positions: + dist_type: "surrounded_and_reflect" + p: 0.5 + map_x: 32 + 
map_y: 32 + # enemy_mask: + # dist_type: "mask" + # mask_probability: 0.5 + # n_enemies: 5 + state_last_action: True + state_timestep_number: False + step_mul: 8 + heuristic_ai: False + # heuristic_rest: False + debug: False + # most severe partial obs setting: + prob_obs_enemy: 0.0 + action_mask: False + +test_nepisode: 32 +test_interval: 10000 +log_interval: 2000 +runner_log_interval: 2000 +learner_log_interval: 2000 +t_max: 10050000 \ No newline at end of file diff --git a/rl_games/configs/smac/v2/protos_5_v_5.yaml b/rl_games/configs/smac/v2/protos_5_v_5.yaml new file mode 100644 index 00000000..1a853c70 --- /dev/null +++ b/rl_games/configs/smac/v2/protos_5_v_5.yaml @@ -0,0 +1,86 @@ +params: + algo: + name: a2c_discrete + + model: + name: discrete_a2c + + network: + name: actor_critic + separate: False + #normalization: layer_norm + space: + discrete: + + mlp: + units: [512, 256] + activation: relu + initializer: + name: default + rnn: + name: lstm + units: 128 + layers: 1 + layer_norm: False + + config: + name: protos_5_v_5 + reward_shaper: + scale_value: 1 + + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-4 + score_to_win: 20 + entropy_coef: 0.005 + truncate_grads: True + grad_norm: 10 + env_name: smac_v2 + e_clip: 0.2 + clip_value: False + num_actors: 16 + horizon_length: 256 + minibatch_size: 2560 # 5 * 512 + mini_epochs: 4 + critic_coef: 1 + lr_schedule: linear + kl_threshold: 0.05 + normalize_input: True + normalize_value: True + use_action_masks: True + max_epochs: 4000 + seq_length: 16 + + player: + games_num: 200 + env_config: + name: 'COULD_BE_IGNORED' + path: 'rl_games/configs/smac/v2/env_configs/sc2_gen_protoss.yaml' + frames: 1 + transpose: False + random_invalid_step: False + central_value: True + apply_agent_ids: True + + central_value_config: + minibatch_size: 512 + mini_epochs: 4 + learning_rate: 5e-4 + clip_value: True + normalize_input: True + network: + name: actor_critic + central_value: True + mlp: + units: [512, 256] + activation: relu + initializer: + name: default + regularizer: + name: None + rnn: + name: lstm + units: 128 + layers: 1 + layer_norm: False \ No newline at end of file diff --git a/rl_games/configs/smac/v2/terran_5_v_5.yaml b/rl_games/configs/smac/v2/terran_5_v_5.yaml new file mode 100644 index 00000000..472e549e --- /dev/null +++ b/rl_games/configs/smac/v2/terran_5_v_5.yaml @@ -0,0 +1,86 @@ +params: + algo: + name: a2c_discrete + + model: + name: discrete_a2c + + network: + name: actor_critic + separate: False + #normalization: layer_norm + space: + discrete: + + mlp: + units: [512, 256] + activation: relu + initializer: + name: default + rnn: + name: lstm + units: 128 + layers: 1 + layer_norm: False + + config: + name: terran_5_v_5 + reward_shaper: + scale_value: 1 + + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-4 + score_to_win: 20 + entropy_coef: 0.005 + truncate_grads: True + grad_norm: 10 + env_name: smac_v2 + e_clip: 0.2 + clip_value: False + num_actors: 16 + horizon_length: 256 + minibatch_size: 2560 # 5 * 512 + mini_epochs: 4 + critic_coef: 1 + lr_schedule: linear + kl_threshold: 0.05 + normalize_input: True + normalize_value: True + use_action_masks: True + max_epochs: 4000 + seq_length: 16 + + player: + games_num: 200 + env_config: + name: 'COULD_BE_IGNORED' + path: 'rl_games/configs/smac/v2/env_configs/sc2_gen_terran.yaml' + frames: 1 + transpose: False + random_invalid_step: False + central_value: True + apply_agent_ids: True + + central_value_config: + minibatch_size: 512 + 
mini_epochs: 4 + learning_rate: 5e-4 + clip_value: True + normalize_input: True + network: + name: actor_critic + central_value: True + mlp: + units: [512, 256] + activation: relu + initializer: + name: default + regularizer: + name: None + rnn: + name: lstm + units: 128 + layers: 1 + layer_norm: False \ No newline at end of file diff --git a/rl_games/configs/smac/v2/zerg_5_v_5.yaml b/rl_games/configs/smac/v2/zerg_5_v_5.yaml new file mode 100644 index 00000000..bd54d46d --- /dev/null +++ b/rl_games/configs/smac/v2/zerg_5_v_5.yaml @@ -0,0 +1,86 @@ +params: + algo: + name: a2c_discrete + + model: + name: discrete_a2c + + network: + name: actor_critic + separate: False + #normalization: layer_norm + space: + discrete: + + mlp: + units: [512, 256] + activation: relu + initializer: + name: default + rnn: + name: lstm + units: 128 + layers: 1 + layer_norm: False + + config: + name: zerg_5_v_5 + reward_shaper: + scale_value: 1 + + normalize_advantage: True + gamma: 0.99 + tau: 0.95 + learning_rate: 2e-4 + score_to_win: 20 + entropy_coef: 0.005 + truncate_grads: True + grad_norm: 10 + env_name: smac_v2 + e_clip: 0.2 + clip_value: False + num_actors: 16 + horizon_length: 256 + minibatch_size: 2560 # 5 * 512 + mini_epochs: 4 + critic_coef: 1 + lr_schedule: linear + kl_threshold: 0.05 + normalize_input: True + normalize_value: True + use_action_masks: True + max_epochs: 4000 + seq_length: 16 + + player: + games_num: 200 + env_config: + name: 'COULD_BE_IGNORED' + path: 'rl_games/configs/smac/v2/env_configs/sc2_gen_zerg.yaml' + frames: 1 + transpose: False + random_invalid_step: False + central_value: True + apply_agent_ids: True + + central_value_config: + minibatch_size: 512 + mini_epochs: 4 + learning_rate: 5e-4 + clip_value: True + normalize_input: True + network: + name: actor_critic + central_value: True + mlp: + units: [512, 256] + activation: relu + initializer: + name: default + regularizer: + name: None + rnn: + name: lstm + units: 128 + layers: 1 + layer_norm: False \ No newline at end of file diff --git a/rl_games/distributed/__init__.py b/rl_games/distributed/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/rl_games/envs/__init__.py b/rl_games/envs/__init__.py index 69a343a6..6883b34a 100644 --- a/rl_games/envs/__init__.py +++ b/rl_games/envs/__init__.py @@ -1,8 +1,6 @@ -from rl_games.envs.connect4_network import ConnectBuilder from rl_games.envs.test_network import TestNetBuilder from rl_games.algos_torch import model_builder -model_builder.register_network('connect4net', ConnectBuilder) model_builder.register_network('testnet', TestNetBuilder) \ No newline at end of file diff --git a/rl_games/envs/connect4_network.py b/rl_games/envs/connect4_network.py deleted file mode 100644 index 4ef52118..00000000 --- a/rl_games/envs/connect4_network.py +++ /dev/null @@ -1,99 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F - -class ConvBlock(nn.Module): - def __init__(self): - super(ConvBlock, self).__init__() - self.action_size = 7 - self.conv1 = nn.Conv2d(4, 128, 3, stride=1, padding=1) - self.bn1 = nn.BatchNorm2d(128) - - def forward(self, s): - s = s['obs'].contiguous() - #s = s.view(-1, 3, 6, 7) # batch_size x channels x board_x x board_y - s = F.relu(self.bn1(self.conv1(s))) - return s - - - -class ResBlock(nn.Module): - def __init__(self, inplanes=128, planes=128, stride=1, downsample=None): - super(ResBlock, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, - padding=1, bias=False) - self.bn1 = 
nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - - def forward(self, x): - residual = x - out = F.relu(self.bn1(self.conv1(x))) - out = self.bn2(self.conv2(out)) - out += residual - out = F.relu(out) - return out - - - -class OutBlock(nn.Module): - def __init__(self): - super(OutBlock, self).__init__() - self.conv = nn.Conv2d(128, 3, kernel_size=1) # value head - self.bn = nn.BatchNorm2d(3) - self.fc1 = nn.Linear(3*6*7, 32) - self.fc2 = nn.Linear(32, 1) - - self.conv1 = nn.Conv2d(128, 32, kernel_size=1) # policy head - self.bn1 = nn.BatchNorm2d(32) - self.fc = nn.Linear(6*7*32, 7) - - def forward(self,s): - v = F.relu(self.bn(self.conv(s))) # value head - v = v.view(-1, 3*6*7) # batch_size X channel X height X width - v = F.relu(self.fc1(v)) - v = F.relu(self.fc2(v)) - v = torch.tanh(v) - - p = F.relu(self.bn1(self.conv1(s))) # policy head - p = p.view(-1, 6*7*32) - p = self.fc(p) - return p, v, None - -class ConnectNet(nn.Module): - def __init__(self, blocks): - super(ConnectNet, self).__init__() - self.blocks = blocks - self.conv = ConvBlock() - for block in range(self.blocks): - setattr(self, "res_%i" % block,ResBlock()) - self.outblock = OutBlock() - def is_rnn(self): - return False - - def forward(self,s): - s = s.permute((0, 3, 1, 2)) - s = self.conv(s) - for block in range(self.blocks): - s = getattr(self, "res_%i" % block)(s) - s = self.outblock(s) - return s - - -from rl_games.algos_torch.network_builder import NetworkBuilder - -class ConnectBuilder(NetworkBuilder): - def __init__(self, **kwargs): - NetworkBuilder.__init__(self) - - def load(self, params): - self.params = params - self.blocks = params['blocks'] - - def build(self, name, **kwargs): - return ConnectNet(self.blocks) - - def __call__(self, name, **kwargs): - return self.build(name, **kwargs) - diff --git a/rl_games/envs/connect4_selfplay.py b/rl_games/envs/connect4_selfplay.py deleted file mode 100644 index 913e1e2a..00000000 --- a/rl_games/envs/connect4_selfplay.py +++ /dev/null @@ -1,132 +0,0 @@ -import gym -import numpy as np -from pettingzoo.classic import connect_four_v0 -import yaml -from rl_games.torch_runner import Runner -import os -from collections import deque - - -class ConnectFourSelfPlay(gym.Env): - def __init__(self, name="connect_four_v0", **kwargs): - gym.Env.__init__(self) - self.name = name - self.is_deterministic = kwargs.pop('is_deterministic', False) - self.is_human = kwargs.pop('is_human', False) - self.random_agent = kwargs.pop('random_agent', False) - self.config_path = kwargs.pop('config_path') - self.agent = None - - self.env = connect_four_v0.env() # gym.make(name, **kwargs) - self.action_space = self.env.action_spaces['player_0'] - observation_space = self.env.observation_spaces['player_0'] - shp = observation_space.shape - self.observation_space = gym.spaces.Box( - low=0, high=1, shape=(shp[:-1] + (shp[-1] * 2,)), dtype=np.uint8) - self.obs_deque = deque([], maxlen=2) - self.agent_id = 0 - - def _get_legal_moves(self, agent_id): - name = 'player_0' if agent_id == 0 else 'player_1' - action_ids = self.env.infos[name]['legal_moves'] - mask = np.zeros(self.action_space.n, dtype=np.bool) - mask[action_ids] = True - return mask, action_ids - - def env_step(self, action): - obs = self.env.step(action) - info = {} - name = 'player_0' if self.agent_id == 0 else 'player_1' - reward = self.env.rewards[name] - done = self.env.dones[name] - return obs, reward, done, info - - def get_obs(self): 
- return np.concatenate(self.obs_deque, -1).astype(np.uint8) * 255 - - def reset(self): - if self.agent == None: - self.create_agent(self.config_path) - - self.agent_id = np.random.randint(2) - obs = self.env.reset() - self.obs_deque.append(obs) - self.obs_deque.append(obs) - if self.agent_id == 1: - op_obs = self.get_obs() - op_obs = self.agent.obs_to_torch(op_obs) - mask, ids = self._get_legal_moves(0) - if self.is_human: - self.render() - opponent_action = int(input()) - else: - if self.random_agent: - opponent_action = np.random.choice(ids, 1)[0] - else: - opponent_action = self.agent.get_masked_action( - op_obs, mask, self.is_deterministic).item() - - obs, _, _, _ = self.env_step(opponent_action) - - self.obs_deque.append(obs) - return self.get_obs() - - def create_agent(self, config): - with open(config, 'r') as stream: - config = yaml.safe_load(stream) - runner = Runner() - runner.load(config) - config = runner.get_prebuilt_config() - # 'RAYLIB has bug here, CUDA_VISIBLE_DEVICES become unset' - if 'CUDA_VISIBLE_DEVICES' in os.environ: - os.environ.pop('CUDA_VISIBLE_DEVICES') - - self.agent = runner.create_player() - self.agent.model.eval() - - def step(self, action): - - obs, reward, done, info = self.env_step(action) - self.obs_deque.append(obs) - - if done: - if reward == 1: - info['battle_won'] = 1 - else: - info['battle_won'] = 0 - return self.get_obs(), reward, done, info - - op_obs = self.get_obs() - - op_obs = self.agent.obs_to_torch(op_obs) - mask, ids = self._get_legal_moves(1-self.agent_id) - if self.is_human: - self.render() - opponent_action = int(input()) - else: - if self.random_agent: - opponent_action = np.random.choice(ids, 1)[0] - else: - opponent_action = self.agent.get_masked_action( - op_obs, mask, self.is_deterministic).item() - obs, reward, done, _ = self.env_step(opponent_action) - if done: - if reward == -1: - info['battle_won'] = 0 - else: - info['battle_won'] = 1 - self.obs_deque.append(obs) - return self.get_obs(), reward, done, info - - def render(self, mode='ansi'): - self.env.render(mode) - - def update_weights(self, weigths): - self.agent.set_weights(weigths) - - def get_action_mask(self): - mask, _ = self._get_legal_moves(self.agent_id) - return mask - - def has_action_mask(self): - return True diff --git a/rl_games/envs/multiwalker.py b/rl_games/envs/multiwalker.py index de3ab71a..f421bd61 100644 --- a/rl_games/envs/multiwalker.py +++ b/rl_games/envs/multiwalker.py @@ -5,7 +5,6 @@ from rl_games.torch_runner import Runner import os from collections import deque -import rl_games.envs.connect4_network class MultiWalker(gym.Env): def __init__(self, name="multiwalker", **kwargs): diff --git a/rl_games/envs/smac_env.py b/rl_games/envs/smac_env.py index a4039e55..79695743 100644 --- a/rl_games/envs/smac_env.py +++ b/rl_games/envs/smac_env.py @@ -96,7 +96,7 @@ def step(self, actions): return obses, rewards, dones, info def get_action_mask(self): - return np.array(self.env.get_avail_actions(), dtype=np.bool) + return np.array(self.env.get_avail_actions(), dtype=bool) def has_action_mask(self): return not self.random_invalid_step diff --git a/rl_games/envs/smac_v2_env.py b/rl_games/envs/smac_v2_env.py new file mode 100644 index 00000000..1612bbe5 --- /dev/null +++ b/rl_games/envs/smac_v2_env.py @@ -0,0 +1,112 @@ +import gym +import numpy as np +import yaml +from smacv2.env import StarCraft2Env +from smacv2.env import MultiAgentEnv +from smacv2.env.starcraft2.wrapper import StarCraftCapabilityEnvWrapper + +class SMACEnvV2(gym.Env): + def __init__(self, 
name="3m", **kwargs): + gym.Env.__init__(self) + self._seed = kwargs.pop('seed', None) + self.path = kwargs.pop('path') + self.reward_sparse = kwargs.get('reward_sparse', False) + self.use_central_value = kwargs.pop('central_value', True) + self.concat_infos = True + self.random_invalid_step = kwargs.pop('random_invalid_step', False) + self.replay_save_freq = kwargs.pop('replay_save_freq', 10000) + self.apply_agent_ids = kwargs.pop('apply_agent_ids', True) + with open(self.path, 'r') as stream: + config = yaml.safe_load(stream) + env_args = config['env_args'] + self.env = StarCraftCapabilityEnvWrapper(seed=self._seed, **env_args) + self.env_info = self.env.get_env_info() + + self._game_num = 0 + self.n_actions = self.env_info["n_actions"] + self.n_agents = self.env_info["n_agents"] + self.action_space = gym.spaces.Discrete(self.n_actions) + one_hot_agents = 0 + + if self.apply_agent_ids: + one_hot_agents = self.n_agents + self.observation_space = gym.spaces.Box(low=0, high=1, shape=(self.env_info['obs_shape']+one_hot_agents, ), dtype=np.float32) + self.state_space = gym.spaces.Box(low=0, high=1, shape=(self.env_info['state_shape'], ), dtype=np.float32) + + self.obs_dict = {} + + def _preproc_state_obs(self, state, obs): + # todo: remove from self + if self.apply_agent_ids: + num_agents = self.n_agents + obs = np.array(obs) + all_ids = np.eye(num_agents, dtype=np.float32) + obs = np.concatenate([obs, all_ids], axis=-1) + + self.obs_dict["obs"] = np.array(obs) + self.obs_dict["state"] = np.array(state) + + if self.use_central_value: + return self.obs_dict + else: + return self.obs_dict["obs"] + + def get_number_of_agents(self): + return self.n_agents + + def reset(self): + if self._game_num % self.replay_save_freq == 1: + print('saving replay') + self.env.save_replay() + self._game_num += 1 + obs, state = self.env.reset() # rename, to think remove + obs_dict = self._preproc_state_obs(state, obs) + + return obs_dict + + def _preproc_actions(self, actions): + actions = actions.copy() + rewards = np.zeros_like(actions) + mask = self.get_action_mask() + for ind, action in enumerate(actions, start=0): + avail_actions = np.nonzero(mask[ind])[0] + if action not in avail_actions: + actions[ind] = np.random.choice(avail_actions) + #rewards[ind] = -0.05 + return actions, rewards + + def step(self, actions): + fixed_rewards = None + + if self.random_invalid_step: + actions, fixed_rewards = self._preproc_actions(actions) + + reward, done, info = self.env.step(actions) + time_out = self.env._episode_steps >= self.env.episode_limit + info['time_outs'] = [time_out]*self.n_agents + + if done: + battle_won = info.get('battle_won', False) + if not battle_won and self.reward_sparse: + reward = -1.0 + + obs = self.env.get_obs() + state = self.env.get_state() + obses = self._preproc_state_obs(state, obs) + rewards = np.repeat (reward, self.n_agents) + dones = np.repeat (done, self.n_agents) + + if fixed_rewards is not None: + rewards += fixed_rewards + + return obses, rewards, dones, info + + def get_action_mask(self): + return np.array(self.env.get_avail_actions(), dtype=bool) + + def has_action_mask(self): + return not self.random_invalid_step + + def seed(self, _): + pass +