Commit
cleanup initial commit
DenSumy committed Oct 8, 2023
1 parent fe95913 commit 3ad7c3f
Showing 63 changed files with 89 additions and 263 deletions.
14 changes: 0 additions & 14 deletions rl_games/common/env_configurations.py
@@ -86,16 +86,6 @@ def create_slime_gym_env(**kwargs):
     env = gym.make(name, **kwargs)
     return env
 
-def create_connect_four_env(**kwargs):
-    from rl_games.envs.connect4_selfplay import ConnectFourSelfPlay
-    name = kwargs.pop('name')
-    limit_steps = kwargs.pop('limit_steps', False)
-    self_play = kwargs.pop('self_play', False)
-    if self_play:
-        env = ConnectFourSelfPlay(name, **kwargs)
-    else:
-        env = gym.make(name, **kwargs)
-    return env
-
 def create_atari_gym_env(**kwargs):
     #frames = kwargs.pop('frames', 1)
@@ -391,10 +381,6 @@ def create_env(name, **kwargs):
         'env_creator' : lambda **kwargs : create_minigrid_env(kwargs.pop('name'), **kwargs),
         'vecenv_type' : 'RAY'
     },
-    'connect4_env' : {
-        'env_creator' : lambda **kwargs : create_connect_four_env(**kwargs),
-        'vecenv_type' : 'RAY'
-    },
     'multiwalker_env' : {
         'env_creator' : lambda **kwargs : create_multiwalker_env(**kwargs),
         'vecenv_type' : 'RAY'
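Note: with the connect4 entry dropped, environments are still registered through the same pattern left in place above — an entry in the module-level configurations dict that maps a name to an 'env_creator' callable and a 'vecenv_type'. A minimal sketch of such an entry, where create_my_custom_env and the 'my_custom_env' key are purely illustrative and not part of this commit:

import gym

def create_my_custom_env(**kwargs):
    # hypothetical factory, mirroring the remaining creators in this file
    name = kwargs.pop('name')
    return gym.make(name, **kwargs)

configurations['my_custom_env'] = {
    'env_creator': lambda **kwargs: create_my_custom_env(**kwargs),
    'vecenv_type': 'RAY'
}
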
4 changes: 2 additions & 2 deletions rl_games/common/experience.py
@@ -20,7 +20,7 @@ def __init__(self, size, ob_space):
         self._next_obses = np.zeros((size,) + ob_space.shape, dtype=ob_space.dtype)
         self._rewards = np.zeros(size)
         self._actions = np.zeros(size, dtype=np.int32)
-        self._dones = np.zeros(size, dtype=np.bool)
+        self._dones = np.zeros(size, dtype=bool)
 
         self._maxsize = size
         self._next_idx = 0
@@ -341,7 +341,7 @@ def _init_from_env_info(self, env_info):
         if self.is_discrete or self.is_multi_discrete:
             self.tensor_dict['actions'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=int), obs_base_shape)
             if self.use_action_masks:
-                self.tensor_dict['action_masks'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape + (np.sum(self.actions_num),), dtype=np.bool), obs_base_shape)
+                self.tensor_dict['action_masks'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape + (np.sum(self.actions_num),), dtype=bool), obs_base_shape)
         if self.is_continuous:
             self.tensor_dict['actions'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=np.float32), obs_base_shape)
             self.tensor_dict['mus'] = self._create_tensor_from_space(gym.spaces.Box(low=0, high=1,shape=self.actions_shape, dtype=np.float32), obs_base_shape)
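Note: the dtype swaps above track NumPy's removal of the np.bool alias — it was deprecated in NumPy 1.20 and removed in 1.24, so dtype=np.bool now raises an AttributeError, while the built-in bool is the drop-in replacement. A tiny standalone check, illustrative only:

import numpy as np

# np.zeros(4, dtype=np.bool)    # AttributeError on NumPy >= 1.24
dones = np.zeros(4, dtype=bool)  # works on all NumPy versions
print(dones.dtype)               # bool
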
25 changes: 13 additions & 12 deletions rl_games/common/vecenv.py
@@ -7,7 +7,6 @@
 from time import sleep
 import torch
 
-
 class RayWorker:
     def __init__(self, config_name, config):
         self.env = configurations[config_name]['env_creator'](**config)
@@ -96,30 +95,32 @@ def get_env_info(self):
 
 
 class RayVecEnv(IVecEnv):
+    import ray
+
     def __init__(self, config_name, num_actors, **kwargs):
         self.config_name = config_name
         self.num_actors = num_actors
         self.use_torch = False
         self.seed = kwargs.pop('seed', None)
 
-        import ray
-        self.remote_worker = ray.remote(RayWorker)
-
+        self.remote_worker = self.ray.remote(RayWorker)
         self.workers = [self.remote_worker.remote(self.config_name, kwargs) for i in range(self.num_actors)]
 
         if self.seed is not None:
             seeds = range(self.seed, self.seed + self.num_actors)
             seed_set = []
             for (seed, worker) in zip(seeds, self.workers):
                 seed_set.append(worker.seed.remote(seed))
-            ray.get(seed_set)
+            self.ray.get(seed_set)
 
         res = self.workers[0].get_number_of_agents.remote()
-        self.num_agents = ray.get(res)
+        self.num_agents = self.ray.get(res)
 
         res = self.workers[0].get_env_info.remote()
-        env_info = ray.get(res)
+        env_info = self.ray.get(res)
         res = self.workers[0].can_concat_infos.remote()
-        can_concat_infos = ray.get(res)
+        can_concat_infos = self.ray.get(res)
         self.use_global_obs = env_info['use_global_observations']
         self.concat_infos = can_concat_infos
         self.obs_type_dict = type(env_info.get('observation_space')) is gym.spaces.Dict
@@ -139,7 +140,7 @@ def step(self, actions):
         for num, worker in enumerate(self.workers):
             res_obs.append(worker.step.remote(actions[self.num_agents * num: self.num_agents * num + self.num_agents]))
 
-        all_res = ray.get(res_obs)
+        all_res = self.ray.get(res_obs)
         for res in all_res:
             cobs, crewards, cdones, cinfos = res
             if self.use_global_obs:
@@ -171,27 +172,27 @@ def step(self, actions):
 
     def get_env_info(self):
         res = self.workers[0].get_env_info.remote()
-        return ray.get(res)
+        return self.ray.get(res)
 
     def set_weights(self, indices, weights):
         res = []
         for ind in indices:
             res.append(self.workers[ind].set_weights.remote(weights))
-        ray.get(res)
+        self.ray.get(res)
 
     def has_action_masks(self):
         return True
 
     def get_action_masks(self):
         mask = [worker.get_action_mask.remote() for worker in self.workers]
-        masks = ray.get(mask)
+        masks = self.ray.get(mask)
         return np.concatenate(masks, axis=0)
 
     def reset(self):
         res_obs = [worker.reset.remote() for worker in self.workers]
         newobs, newstates = [],[]
         for res in res_obs:
-            cobs = ray.get(res)
+            cobs = self.ray.get(res)
             if self.use_global_obs:
                 newobs.append(cobs["obs"])
                 newstates.append(cobs["state"])
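Note: the self.ray calls above rely on the class-level import added at the top of RayVecEnv — an import statement inside a class body executes at class-definition time and binds the module object as a class attribute, so every method can reach it through self without repeating the import. A standalone sketch of that Python pattern (using json instead of ray so it runs anywhere; not rl_games code):

class Wrapper:
    import json  # becomes the class attribute Wrapper.json

    def dump(self, data):
        # instances reach the module through self; no per-method import needed
        return self.json.dumps(data)

print(Wrapper().dump({'a': 1}))  # {"a": 1}
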
37 files renamed without changes.
73 changes: 73 additions & 0 deletions rl_games/configs/smac/v2/5z_torch_cv.yaml
@@ -0,0 +1,73 @@
params:
  algo:
    name: a2c_discrete

  model:
    name: discrete_a2c

  network:
    name: actor_critic
    separate: False
    #normalization: layer_norm
    space:
      discrete:

    mlp:
      units: [256, 128]
      activation: relu
      initializer:
        name: default
      regularizer:
        name: 'None'

  config:
    name: 5z_cv
    reward_shaper:
      scale_value: 1
    normalize_advantage: True
    gamma: 0.99
    tau: 0.95
    learning_rate: 5e-4
    score_to_win: 20
    grad_norm: 0.5
    entropy_coef: 0.001
    truncate_grads: True
    env_name: smac
    e_clip: 0.2
    clip_value: True
    num_actors: 8
    horizon_length: 128
    minibatch_size: 1536 # 3 * 512
    mini_epochs: 4
    critic_coef: 1
    lr_schedule: None
    kl_threshold: 0.05
    normalize_input: True
    normalize_value: False
    use_action_masks: True
    ignore_dead_batches : False

    env_config:
      name: zerg_5_vs_5
      frames: 1
      transpose: False
      random_invalid_step: False
      central_value: True
      reward_only_positive: True
    central_value_config:
      minibatch_size: 512
      mini_epochs: 4
      learning_rate: 5e-4
      clip_value: False
      normalize_input: True
      network:
        name: actor_critic
        central_value: True
        mlp:
          units: [256, 128]
          activation: relu
          initializer:
            name: default
            scale: 2
          regularizer:
            name: 'None'
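
Since the new file follows the standard rl_games config layout, it can presumably be launched like any other config. A hedged sketch of one way to do that, assuming the usual rl_games.torch_runner.Runner API and that the SMACv2 environment dependencies are installed (argument names may differ between versions):

import yaml
from rl_games.torch_runner import Runner

with open('rl_games/configs/smac/v2/5z_torch_cv.yaml') as f:
    config = yaml.safe_load(f)

runner = Runner()
runner.load(config)                         # build algo/env factories from the YAML
runner.run({'train': True, 'play': False})  # start training on zerg_5_vs_5
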
Empty file removed rl_games/distributed/__init__.py
2 changes: 0 additions & 2 deletions rl_games/envs/__init__.py
@@ -1,8 +1,6 @@
 
 
-from rl_games.envs.connect4_network import ConnectBuilder
 from rl_games.envs.test_network import TestNetBuilder
 from rl_games.algos_torch import model_builder
 
-model_builder.register_network('connect4net', ConnectBuilder)
 model_builder.register_network('testnet', TestNetBuilder)
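
For reference, custom networks stay pluggable through the register_network call that survives this commit: a builder class is registered under a string name, and configs refer to that name in their network section. A short sketch where MyNetBuilder and the 'mynet' name are hypothetical, standing in for a user-defined builder with the same interface as TestNetBuilder:

from rl_games.algos_torch import model_builder
from my_project.networks import MyNetBuilder  # hypothetical user module

model_builder.register_network('mynet', MyNetBuilder)
# A YAML config would then select it with:
#   network:
#     name: mynet
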
99 changes: 0 additions & 99 deletions rl_games/envs/connect4_network.py

This file was deleted.
