Fixed masks with multi discrete space. (#265)
Co-authored-by: Denys Makoviichuk <[email protected]>
Denys88 and DenSumy authored Dec 1, 2023
1 parent a5d788a commit a79fcc2
Showing 7 changed files with 161 additions and 9 deletions.
2 changes: 0 additions & 2 deletions rl_games/algos_torch/a2c_discrete.py
@@ -95,8 +95,6 @@ def get_masked_action_values(self, obs, action_masks):
                value = self.get_central_value(input_dict)
                res_dict['values'] = value

        if self.is_multi_discrete:
            action_masks = torch.cat(action_masks, dim=-1)
        res_dict['action_masks'] = action_masks
        return res_dict

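With the torch.cat removed, get_masked_action_values now passes the mask through unchanged, so for a multi-discrete space the environment is expected to deliver one already-flattened boolean mask per actor rather than a list of per-branch masks. A minimal sketch of the expected layout (branch sizes and values here are illustrative, not taken from the repo):

import numpy as np

# Hypothetical multi-discrete space: 3 branches with 9 actions each (e.g. 3 SMAC units).
branch_sizes = [9, 9, 9]

# One boolean row per actor, length == sum of the branch sizes; True marks a legal action.
flat_mask = np.ones((1, sum(branch_sizes)), dtype=bool)
flat_mask[0, 4] = False  # e.g. action 4 of the first branch is currently illegal
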
6 changes: 4 additions & 2 deletions rl_games/algos_torch/models.py
@@ -144,7 +144,8 @@ def forward(self, input_dict):
            if is_train:
                if action_masks is None:
                    categorical = [Categorical(logits=logit) for logit in logits]
                else:
                    action_masks = np.split(action_masks, len(logits), axis=1)
                    categorical = [CategoricalMasked(logits=logit, masks=mask) for logit, mask in zip(logits, action_masks)]
                prev_actions = torch.split(prev_actions, 1, dim=-1)
                prev_neglogp = [-c.log_prob(a.squeeze()) for c, a in zip(categorical, prev_actions)]
@@ -162,7 +163,8 @@
            else:
                if action_masks is None:
                    categorical = [Categorical(logits=logit) for logit in logits]
                else:
                    action_masks = np.split(action_masks, len(logits), axis=1)
                    categorical = [CategoricalMasked(logits=logit, masks=mask) for logit, mask in zip(logits, action_masks)]

                selected_action = [c.sample().long() for c in categorical]
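Here the flattened mask coming from the agent is split back into one mask per discrete branch before the masked distributions are built. A rough, self-contained sketch of that path, with CategoricalMasked's effect emulated by filling masked-out logits with a large negative value (batch size, branch count and mask values are assumptions for illustration):

import numpy as np
import torch
from torch.distributions import Categorical

# Two discrete branches of 4 actions each, batch of 2.
logits = [torch.randn(2, 4), torch.randn(2, 4)]
action_masks = np.array([[1, 1, 0, 1, 1, 0, 1, 1],
                         [1, 0, 1, 1, 0, 1, 1, 1]], dtype=bool)

# np.split yields one (2, 4) mask per branch, matching the per-branch logits.
masks = np.split(action_masks, len(logits), axis=1)
categorical = [
    Categorical(logits=torch.where(torch.as_tensor(m), logit, torch.tensor(-1e8)))
    for logit, m in zip(logits, masks)
]
selected_action = [c.sample().long() for c in categorical]
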
12 changes: 9 additions & 3 deletions rl_games/common/env_configurations.py
@@ -153,11 +153,12 @@ def create_roboschool_env(name):
    return gym.make(name)

def create_smac(name, **kwargs):
    from rl_games.envs.smac_env import SMACEnv
    from rl_games.envs.smac_env import SMACEnv, MultiDiscreteSmacWrapper
    frames = kwargs.pop('frames', 1)
    transpose = kwargs.pop('transpose', False)
    flatten = kwargs.pop('flatten', True)
    has_cv = kwargs.get('central_value', False)
    as_single_agent = kwargs.pop('as_single_agent', False)
    env = SMACEnv(name, **kwargs)


@@ -166,6 +167,9 @@ def create_smac(name, **kwargs):
        env = wrappers.BatchedFrameStackWithStates(env, frames, transpose=False, flatten=flatten)
    else:
        env = wrappers.BatchedFrameStack(env, frames, transpose=False, flatten=flatten)

    if as_single_agent:
        env = MultiDiscreteSmacWrapper(env)
    return env

def create_smac_v2(name, **kwargs):
@@ -184,16 +188,18 @@ def create_smac_v2(name, **kwargs):
    return env

def create_smac_cnn(name, **kwargs):
    from rl_games.envs.smac_env import SMACEnv
    from rl_games.envs.smac_env import SMACEnv, MultiDiscreteSmacWrapper
    has_cv = kwargs.get('central_value', False)
    frames = kwargs.pop('frames', 4)
    transpose = kwargs.pop('transpose', False)
    as_single_agent = kwargs.pop('as_single_agent', False)

    env = SMACEnv(name, **kwargs)
    if has_cv:
        env = wrappers.BatchedFrameStackWithStates(env, frames, transpose=transpose)
    else:
        env = wrappers.BatchedFrameStack(env, frames, transpose=transpose)

    if as_single_agent:
        env = MultiDiscreteSmacWrapper(env)
    return env

def create_test_env(name, **kwargs):
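When as_single_agent is set in the env_config, the SMAC environment gets wrapped so the whole team looks like one agent with a multi-discrete action space. The runner normally builds this from env_name: smac plus env_config, but the factory can also be called directly; a sketch (requires the SMAC package and StarCraft II to be installed):

from rl_games.common.env_configurations import create_smac

env = create_smac('3m', central_value=True, as_single_agent=True)
print(env.get_number_of_agents())  # 1 - the wrapper presents the whole team as a single agent
print(env.action_space)            # Tuple of one Discrete space per controlled unit
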
2 changes: 1 addition & 1 deletion rl_games/common/experience.py
@@ -341,7 +341,7 @@ def _init_from_env_info(self, env_info):
        if self.is_discrete or self.is_multi_discrete:
            self.tensor_dict['actions'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1, shape=self.actions_shape, dtype=int), obs_base_shape)
            if self.use_action_masks:
                self.tensor_dict['action_masks'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1, shape=self.actions_shape + (np.sum(self.actions_num),), dtype=bool), obs_base_shape)
                self.tensor_dict['action_masks'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1, shape=(np.sum(self.actions_num),), dtype=bool), obs_base_shape)
        if self.is_continuous:
            self.tensor_dict['actions'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1, shape=self.actions_shape, dtype=np.float32), obs_base_shape)
            self.tensor_dict['mus'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1, shape=self.actions_shape, dtype=np.float32), obs_base_shape)
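For a multi-discrete space actions_num is a list of branch sizes, so the mask buffer needs one flat row of sum(actions_num) booleans per sample; the old shape tacked that length onto actions_shape and produced an extra dimension. A small shape check with assumed numbers:

import numpy as np

# Assumed setup: 3 branches of 9 actions, horizon_length 128, 8 actors.
actions_num = [9, 9, 9]
actions_shape = (len(actions_num),)   # (3,) - one chosen index per branch
obs_base_shape = (128, 8)

actions_buffer = obs_base_shape + actions_shape                              # (128, 8, 3)
masks_buffer_new = obs_base_shape + (np.sum(actions_num),)                   # (128, 8, 27)
masks_buffer_old = obs_base_shape + actions_shape + (np.sum(actions_num),)   # (128, 8, 3, 27) - too large
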
54 changes: 54 additions & 0 deletions rl_games/configs/smac/v1/3m_torch_sa.yaml
@@ -0,0 +1,54 @@
params:
  algo:
    name: a2c_discrete

  model:
    name: multi_discrete_a2c

  network:
    name: actor_critic
    separate: True
    #normalization: layer_norm
    space:
      multi_discrete:

    mlp:
      units: [256, 128]
      activation: relu
      initializer:
        name: default
      regularizer:
        name: None

  config:
    name: 3m_sa
    reward_shaper:
      scale_value: 1
    normalize_advantage: True
    gamma: 0.99
    tau: 0.95
    learning_rate: 5e-4
    score_to_win: 20
    grad_norm: 0.5
    entropy_coef: 0.001
    truncate_grads: True
    env_name: smac
    e_clip: 0.2
    clip_value: True
    num_actors: 8
    horizon_length: 128
    minibatch_size: 512
    mini_epochs: 4
    critic_coef: 1
    lr_schedule: None
    kl_threshold: 0.05
    normalize_input: True
    use_action_masks: True
    ignore_dead_batches: False

    env_config:
      name: 3m
      frames: 1
      transpose: False
      random_invalid_step: False
      as_single_agent: True
      central_value: True
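Assuming the SMAC dependencies are installed, this config should train through the usual runner entry point, e.g. python runner.py --train --file rl_games/configs/smac/v1/3m_torch_sa.yaml. A programmatic sketch of the same launch (the dict keys passed to run are an assumption about the Runner interface):

import yaml
from rl_games.torch_runner import Runner

with open('rl_games/configs/smac/v1/3m_torch_sa.yaml') as f:
    config = yaml.safe_load(f)

runner = Runner()
runner.load(config)
runner.run({'train': True, 'play': False})
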
60 changes: 60 additions & 0 deletions rl_games/configs/smac/v1/5m_vs_6m_sa.yaml
@@ -0,0 +1,60 @@
params:
  algo:
    name: a2c_discrete

  model:
    name: multi_discrete_a2c

  network:
    name: actor_critic
    separate: True
    space:
      multi_discrete:

    mlp:
      units: [512, 256, 128]
      activation: relu
      initializer:
        name: default

  config:
    name: 5m_vs_6m_sa
    reward_shaper:
      scale_value: 1
    normalize_advantage: True
    gamma: 0.99
    tau: 0.95
    learning_rate: 3e-4
    score_to_win: 20
    entropy_coef: 0.02
    truncate_grads: True
    grad_norm: 1
    env_name: smac
    e_clip: 0.2
    clip_value: False
    num_actors: 8
    horizon_length: 256
    minibatch_size: 1024
    mini_epochs: 4
    critic_coef: 2
    lr_schedule: None
    kl_threshold: 0.05
    normalize_input: True
    normalize_value: False
    use_action_masks: True
    use_diagnostics: True
    seq_length: 8
    max_epochs: 10000

    env_config:
      name: 5m_vs_6m
      central_value: True
      reward_only_positive: True
      obs_last_action: False
      apply_agent_ids: False
      as_single_agent: True

    player:
      render: False
      games_num: 200
      n_game_life: 1
      determenistic: True
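The player block only affects evaluation; a trained checkpoint can be played back with the usual invocation, e.g. python runner.py --play --file rl_games/configs/smac/v1/5m_vs_6m_sa.yaml --checkpoint <path-to-checkpoint> (flag names as in the repo's runner.py).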
34 changes: 33 additions & 1 deletion rl_games/envs/smac_env.py
@@ -101,6 +101,38 @@ def get_action_mask(self):
    def has_action_mask(self):
        return not self.random_invalid_step

    def seed(self, _):
    def seed(self, seed):
        pass
        #self.env.seed(seed)

class MultiDiscreteSmacWrapper(gym.Env):
    def __init__(self, env):
        gym.Env.__init__(self)
        self.env = env
        self.observation_space = self.env.state_space
        self.action_space = gym.spaces.Tuple([self.env.action_space] * self.env.get_number_of_agents())

    def step(self, actions):
        fixed_rewards = None
        obses, reward, done, info = self.env.step(actions)
        return obses['state'], reward[0], done[0], info

    def reset(self):
        obses = self.env.reset()
        return obses['state']

    def has_action_mask(self):
        return self.env.has_action_mask()

    def get_action_mask(self):
        action_masks = self.env.get_action_mask()
        action_masks = action_masks.flatten()
        return np.expand_dims(action_masks, axis=0)

    def get_number_of_agents(self):
        return 1

    def seed(self, seed):
        pass
        #self.env.seed(seed)
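The wrapper turns an n-agent SMAC environment into a single-agent one: the global state becomes the observation, the per-agent Discrete spaces become a Tuple (i.e. multi-discrete) action space, and the (n_agents, n_actions) mask is flattened into one row. A toy stand-in for the SMAC env to show the shapes (importing from smac_env requires the smac package to be installed; the stub's sizes are assumptions for illustration):

import numpy as np
import gym

from rl_games.envs.smac_env import MultiDiscreteSmacWrapper

class StubSmac(gym.Env):
    # Minimal fake with 3 agents and 9 actions each, just enough for the wrapper.
    action_space = gym.spaces.Discrete(9)
    state_space = gym.spaces.Box(low=-1, high=1, shape=(48,))

    def get_number_of_agents(self):
        return 3

    def has_action_mask(self):
        return True

    def get_action_mask(self):
        return np.ones((3, 9), dtype=bool)

    def reset(self):
        return {'obs': np.zeros((3, 30)), 'state': np.zeros(48)}

    def step(self, actions):
        return {'obs': np.zeros((3, 30)), 'state': np.zeros(48)}, [0.0], [False], {}

env = MultiDiscreteSmacWrapper(StubSmac())
print(env.get_number_of_agents())   # 1
print(env.get_action_mask().shape)  # (1, 27) - one flat boolean row for the single "agent"
obs, reward, done, info = env.step([0, 0, 0])
print(obs.shape)                    # (48,) - the global state is used as the observation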
