VDN #22

Open: wants to merge 79 commits into master

Changes from 54 commits

Commits (79)
f899a23
interfaced with sacred
schroederdewitt May 25, 2020
116c9e2
add experiment infrastructure
schroederdewitt May 25, 2020
49a41a5
fixed docker
schroederdewitt May 25, 2020
81ae28d
fixes for docker file
May 25, 2020
a8d6b61
fixed Dockerfile
May 25, 2020
680f8f5
torch runner update
schroederdewitt May 26, 2020
3c6f54b
fixes
schroederdewitt May 26, 2020
b2aacac
fixed run.sh
May 26, 2020
e3bbc33
scalar logging problem
schroederdewitt May 26, 2020
1743460
Merge branch 'master' of github.com:schroederdewitt/rl_games
schroederdewitt May 26, 2020
ac724fc
added some config
schroederdewitt May 28, 2020
caa084b
added shell scripts
schroederdewitt May 28, 2020
130f78f
Merge branch 'master' of github.com:schroederdewitt/rl_games
schroederdewitt May 28, 2020
641344a
added 3s_vs_5z configs
schroederdewitt May 28, 2020
0f004a5
Merge branch 'master' of github.com:schroederdewitt/rl_games
schroederdewitt May 28, 2020
a9d06ae
minor
schroederdewitt May 28, 2020
05310ac
minor
schroederdewitt May 28, 2020
e7fcda2
minor
schroederdewitt May 28, 2020
7061aef
added MM2_torch.yaml
schroederdewitt May 28, 2020
902f789
lunch
May 28, 2020
b165180
Merge branch 'master' of github.com:schroederdewitt/rl_games
May 28, 2020
7b4f1be
interfaced tf code
schroederdewitt May 29, 2020
d5b3e37
Merge branch 'master' of github.com:schroederdewitt/rl_games
schroederdewitt May 29, 2020
e2598d2
added tf baselines
schroederdewitt May 29, 2020
5503326
added more config
schroederdewitt May 30, 2020
dc5f049
added additional maps
schroederdewitt Jun 14, 2020
fe71d50
fix
schroederdewitt Jun 14, 2020
f7a54c6
vdn start
tarun018 Jun 21, 2020
8b9c3db
up
tarun018 Jun 21, 2020
d7e0d91
updates
tarun018 Jun 21, 2020
76fbc52
state added to exp replay
tarun018 Jun 21, 2020
8dc9c61
rudimentary vdn ready
tarun018 Jun 21, 2020
b3a2870
vdn as a model
tarun018 Jun 21, 2020
55b60bf
vdn conf
tarun018 Jun 22, 2020
b22648f
env config use in vdn
tarun018 Jun 22, 2020
fb0161d
bug correct
tarun018 Jun 22, 2020
c6fa74f
updated MMM2_torch.yaml in order to re-benchmark
schroederdewitt Jun 22, 2020
cbf85fe
grad norm with truncate option and a bug update
tarun018 Jun 22, 2020
279bf6a
bug correct
tarun018 Jun 22, 2020
e8e8f7e
major changes
tarun018 Jun 22, 2020
650c741
final update
tarun018 Jun 22, 2020
072d2ac
created ReplayBufferCentralState
schroederdewitt Jun 23, 2020
a73b87c
Merge pull request #1 from schroederdewitt/vdn
schroederdewitt Jun 23, 2020
631b4f3
init for centralized state
schroederdewitt Jun 23, 2020
506cd70
Merge branch 'master' of github.com:schroederdewitt/rl_games
schroederdewitt Jun 23, 2020
4b805f4
vdn implementation fix
schroederdewitt Jun 23, 2020
2a8f87f
added vdn config files
schroederdewitt Jun 23, 2020
f9a1abe
minor fixes
schroederdewitt Jun 30, 2020
24ba401
add plotting and examples
mingfeisun Jul 1, 2020
159b8f7
fix for VDN logger
schroederdewitt Jul 2, 2020
d804a25
Merge branch 'master' of github.com:schroederdewitt/rl_games
schroederdewitt Jul 2, 2020
e2a0833
bug fix for state shape
tarun018 Jul 2, 2020
554258c
Merge pull request #2 from schroederdewitt/vdn_s
schroederdewitt Jul 2, 2020
a3c5a6d
minor logging fix
schroederdewitt Jul 2, 2020
7c82843
fixed rl_games dockerfile cuda version
schroederdewitt Jul 15, 2020
d48e63c
Merge branch 'master' of github.com:schroederdewitt/rl_games
schroederdewitt Jul 15, 2020
31c08ab
minor
Jul 15, 2020
8f8be0c
fixed launch servers
Jul 15, 2020
c19537f
minor
schroederdewitt Jul 15, 2020
4e0999f
Merge branch 'master' of github.com:schroederdewitt/rl_games
schroederdewitt Jul 15, 2020
aa2ae51
updated experience replay buffer
tarun018 Jul 21, 2020
19dbf55
iql with normal buffer
tarun018 Jul 21, 2020
a7ced6d
bug update
tarun018 Jul 25, 2020
8c7d251
update
tarun018 Jul 25, 2020
f97e82a
updated config
tarun018 Jul 25, 2020
42a954d
test dynamic growth
tarun018 Jul 25, 2020
a67d631
no devide placement
tarun018 Jul 25, 2020
0e4269a
config update
tarun018 Jul 25, 2020
79916cf
added stag_hunt (not yet fully working)
schroederdewitt Jul 25, 2020
e2fdc72
dockerfile with cpu
tarun018 Aug 16, 2020
6ede081
Merge branch 'master' of github.com:schroederdewitt/rl_games
tarun018 Aug 16, 2020
a971642
removed central state code
schroederdewitt Aug 24, 2020
6ceb9e1
added staghunt (no central state) for ppo
schroederdewitt Aug 24, 2020
bfaee18
staghunt print fix
schroederdewitt Aug 24, 2020
bf11583
stag hunt fix
Oct 18, 2020
438a4a5
Merge branch 'master' of github.com:schroederdewitt/rl_games into master
Oct 18, 2020
466ef5e
new config
tarun018 Dec 24, 2020
91f9ea5
update max epochs
tarun018 Dec 27, 2020
444e71e
new c
tarun018 Jan 4, 2021
6 changes: 5 additions & 1 deletion .gitignore
@@ -108,4 +108,8 @@ venv.bak/
.vscode

/nn
/runs
/runs
db_config.private.yaml
exp_scripts/
.idea/
analysis/
1 change: 1 addition & 0 deletions 3rdparty/gym_0_10_8
1 change: 1 addition & 0 deletions 3rdparty/multiagent_mujoco
1 change: 1 addition & 0 deletions 3rdparty/multiagent_particle_envs
107 changes: 89 additions & 18 deletions algos_tf14/a2c_discrete.py
@@ -21,8 +21,14 @@ def swap_and_flatten01(arr):
return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])

class A2CAgent:
def __init__(self, sess, base_name, observation_space, action_space, config):
def __init__(self, sess, base_name, observation_space, action_space, config, logger, central_state_space=None):
observation_shape = observation_space.shape

self.use_central_states = False
if central_state_space is not None:
self.use_central_states = True
central_state_shape = central_state_space.shape

self.use_action_masks = config.get('use_action_masks', False)
self.is_train = config.get('is_train', True)
self.self_play = config.get('self_play', False)
@@ -67,6 +73,8 @@ def __init__(self, sess, base_name, observation_space, action_space, config):
self.game_lengths = deque([], maxlen=self.games_to_log)
self.game_scores = deque([], maxlen=self.games_to_log)
self.obs_ph = tf.placeholder(observation_space.dtype, (None, ) + observation_shape, name = 'obs')
if self.use_central_states:
self.central_states_ph = tf.placeholder(central_state_space.dtype, (None, ) + central_state_shape, name = 'central_state')
self.target_obs_ph = tf.placeholder(observation_space.dtype, (None, ) + observation_shape, name = 'target_obs')
self.actions_num = action_space.n
self.actions_ph = tf.placeholder('int32', (None,), name = 'actions')
@@ -84,6 +92,9 @@ def __init__(self, sess, base_name, observation_space, action_space, config):
self.update_epoch_op = self.epoch_num.assign(self.epoch_num + 1)
self.current_lr = self.learning_rate_ph

#if self.use_central_states:
# self.input_obs = self.central_states_ph
#else:
self.input_obs = self.obs_ph
self.input_target_obs = self.target_obs_ph

@@ -114,6 +125,9 @@ def __init__(self, sess, base_name, observation_space, action_space, config):
'action_mask_ph' : None
}

if self.use_central_states:
self.train_dict["central_states"] = self.central_states_ph

self.run_dict = {
'name' : 'agent',
'inputs' : self.input_target_obs,
@@ -124,11 +138,14 @@ def __init__(self, sess, base_name, observation_space, action_space, config):
'action_mask_ph' : self.action_mask_ph
}

self.states = None
if self.use_central_states:
self.train_dict["central_states"] = self.central_states_ph

self.rnn_states = None
if self.network.is_rnn():
self.logp_actions ,self.state_values, self.action, self.entropy, self.states_ph, self.masks_ph, self.lstm_state, self.initial_state = self.network(self.train_dict, reuse=False)
self.logp_actions, self.state_values, self.action, self.entropy, self.rnn_states_ph, self.masks_ph, self.lstm_state, self.initial_state = self.network(self.train_dict, reuse=False)
self.target_neglogp, self.target_state_values, self.target_action, _, self.target_states_ph, self.target_masks_ph, self.target_lstm_state, self.target_initial_state, self.logits = self.network(self.run_dict, reuse=True)
self.states = self.target_initial_state
self.rnn_states = self.target_initial_state

else:
self.logp_actions ,self.state_values, self.action, self.entropy = self.network(self.train_dict, reuse=False)
@@ -142,6 +159,10 @@ def __init__(self, sess, base_name, observation_space, action_space, config):

self.sess.run(tf.global_variables_initializer())

self.logger = logger

self.num_env_steps_train = 0

def setup_losses(self):
curr_e_clip = self.e_clip * self.lr_multiplier
if (self.ppo):
@@ -192,22 +213,22 @@ def get_action_values(self, obs):
run_ops = [self.target_action, self.target_state_values, self.target_neglogp]
if self.network.is_rnn():
run_ops.append(self.target_lstm_state)
return self.sess.run(run_ops, {self.target_obs_ph : obs, self.target_states_ph : self.states, self.target_masks_ph : self.dones})
return self.sess.run(run_ops, {self.target_obs_ph : obs, self.target_states_ph : self.rnn_states, self.target_masks_ph : self.dones})
else:
return (*self.sess.run(run_ops, {self.target_obs_ph : obs}), None)

def get_masked_action_values(self, obs, action_masks):
run_ops = [self.target_action, self.target_state_values, self.target_neglogp, self.logits]
if self.network.is_rnn():
run_ops.append(self.target_lstm_state)
return self.sess.run(run_ops, {self.action_mask_ph: action_masks, self.target_obs_ph : obs, self.target_states_ph : self.states, self.target_masks_ph : self.dones})
return self.sess.run(run_ops, {self.action_mask_ph: action_masks, self.target_obs_ph : obs, self.target_states_ph : self.rnn_states, self.target_masks_ph : self.dones})
else:
return (*self.sess.run(run_ops, {self.action_mask_ph: action_masks, self.target_obs_ph : obs}), None)


def get_values(self, obs):
if self.network.is_rnn():
return self.sess.run([self.target_state_values], {self.target_obs_ph : obs, self.target_states_ph : self.states, self.target_masks_ph : self.dones})
return self.sess.run([self.target_state_values], {self.target_obs_ph : obs, self.target_states_ph : self.rnn_states, self.target_masks_ph : self.dones})
else:
return self.sess.run([self.target_state_values], {self.target_obs_ph : obs})

@@ -222,33 +243,44 @@ def set_weights(self, weights):
def play_steps(self):
# here, we init the lists that will contain the mb of experiences
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [],[],[],[],[],[]

mb_states = []

if self.use_central_states:
mb_central_states = []

mb_rnn_states = []
epinfos = []

# for n in range number of steps
for _ in range(self.steps_num):
if self.network.is_rnn():
mb_states.append(self.states)
mb_rnn_states.append(self.rnn_states)

if self.use_action_masks:
masks = self.vec_env.get_action_masks()

if self.use_action_masks:
actions, values, neglogpacs, _, self.states = self.get_masked_action_values(self.obs, masks)
actions, values, neglogpacs, _, self.rnn_states = self.get_masked_action_values(self.obs, masks)
else:
actions, values, neglogpacs, self.states = self.get_action_values(self.obs)
actions, values, neglogpacs, self.rnn_states = self.get_action_values(self.obs)

actions = np.squeeze(actions)
values = np.squeeze(values)
neglogpacs = np.squeeze(neglogpacs)
mb_obs.append(self.obs.copy())
if self.use_central_states:
mb_central_states.append(self.central_states.copy())
mb_actions.append(actions)
mb_values.append(values)
mb_neglogpacs.append(neglogpacs)
mb_dones.append(self.dones.copy())

self.obs[:], rewards, self.dones, infos = self.vec_env.step(actions)
if self.use_central_states:
self.central_states[:] = self.vec_env.get_states()

# Increase step count by self.num_actors (WHIRL)
self.num_env_steps_train += self.num_actors

self.current_rewards += rewards

self.current_lengths += 1
@@ -268,12 +300,14 @@ def play_steps(self):

#using openai baseline approach
mb_obs = np.asarray(mb_obs, dtype=self.obs.dtype)
if self.use_central_states:
mb_central_states = np.asarray(mb_central_states, dtype=self.obs.dtype)
mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
mb_actions = np.asarray(mb_actions, dtype=np.float32)
mb_values = np.asarray(mb_values, dtype=np.float32)
mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
mb_dones = np.asarray(mb_dones, dtype=np.bool)
mb_states = np.asarray(mb_states, dtype=np.float32)
mb_rnn_states = np.asarray(mb_rnn_states, dtype=np.float32)
last_values = self.get_values(self.obs)
last_values = np.squeeze(last_values)

@@ -294,9 +328,19 @@ def play_steps(self):

mb_returns = mb_advs + mb_values
if self.network.is_rnn():
result = (*map(swap_and_flatten01, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, mb_states )), epinfos)
if self.use_central_states:
result = (*map(swap_and_flatten01,
(mb_central_states, mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, mb_rnn_states)),
epinfos)
else:
result = (*map(swap_and_flatten01, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, mb_rnn_states )), epinfos)
else:
result = (*map(swap_and_flatten01, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs)), None, epinfos)
if self.use_central_states:
result = (
*map(swap_and_flatten01, (mb_central_states, mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs)), None,
epinfos)
else:
result = (*map(swap_and_flatten01, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs)), None, epinfos)
return result

def save(self, fn):
@@ -307,6 +351,8 @@ def restore(self, fn):

def train(self):
self.obs = self.vec_env.reset()
if self.use_central_states:
self.central_states = self.vec_env.get_states()
batch_size = self.steps_num * self.num_actors * self.num_agents
batch_size_envs = self.steps_num * self.num_actors
minibatch_size = self.config['minibatch_size']
@@ -327,7 +373,10 @@ def train(self):
play_time_start = time.time()
epoch_num = self.update_epoch()
frame += batch_size_envs
obses, returns, dones, actions, values, neglogpacs, lstm_states, _ = self.play_steps()
if self.use_central_states:
central_states, obses, returns, dones, actions, values, neglogpacs, lstm_states, _ = self.play_steps()
else:
obses, returns, dones, actions, values, neglogpacs, lstm_states, _ = self.play_steps()
advantages = returns - values
if self.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
@@ -363,7 +412,7 @@ def train(self):
dict[self.obs_ph] = obses[mbatch]
dict[self.masks_ph] = dones[mbatch]

dict[self.states_ph] = lstm_states[batch]
dict[self.rnn_states_ph] = lstm_states[batch]

dict[self.learning_rate_ph] = last_lr
run_ops = [self.actor_loss, self.critic_loss, self.entropy, self.kl_approx, self.current_lr, self.lr_multiplier, self.train_op]
@@ -383,11 +432,13 @@ def train(self):
values = values[permutation]
neglogpacs = neglogpacs[permutation]
advantages = advantages[permutation]
central_states = central_states[permutation]

for i in range(0, num_minibatches):
batch = range(i * minibatch_size, (i + 1) * minibatch_size)
dict = {self.obs_ph: obses[batch], self.actions_ph : actions[batch], self.rewards_ph : returns[batch],
self.advantages_ph : advantages[batch], self.old_logp_actions_ph : neglogpacs[batch], self.old_values_ph : values[batch]}
self.advantages_ph : advantages[batch], self.old_logp_actions_ph : neglogpacs[batch], self.old_values_ph : values[batch],
self.central_states_ph: central_states[batch]}
dict[self.learning_rate_ph] = last_lr
run_ops = [self.actor_loss, self.critic_loss, self.entropy, self.kl_approx, self.current_lr, self.lr_multiplier, self.train_op]

@@ -417,6 +468,18 @@ def train(self):
self.writer.add_scalar('info/e_clip', self.e_clip * lr_mul, frame)
self.writer.add_scalar('info/kl', np.mean(kls), frame)
self.writer.add_scalar('epochs', epoch_num, frame)

self.logger.log_stat("whirl/performance/fps", batch_size / scaled_time, self.num_env_steps_train)
self.logger.log_stat("whirl/performance/upd_time", update_time, self.num_env_steps_train)
self.logger.log_stat("whirl/performance/play_time", play_time, self.num_env_steps_train)
self.logger.log_stat("whirl/losses/a_loss", np.asscalar(np.mean(a_losses)), self.num_env_steps_train)
self.logger.log_stat("whirl/losses/c_loss", np.asscalar(np.mean(c_losses)), self.num_env_steps_train)
self.logger.log_stat("whirl/losses/entropy", np.asscalar(np.mean(entropies)), self.num_env_steps_train)
self.logger.log_stat("whirl/info/last_lr", last_lr * lr_mul, self.num_env_steps_train)
self.logger.log_stat("whirl/info/lr_mul", lr_mul, self.num_env_steps_train)
self.logger.log_stat("whirl/info/e_clip", self.e_clip * lr_mul, self.num_env_steps_train)
self.logger.log_stat("whirl/info/kl", np.asscalar(np.mean(kls)), self.num_env_steps_train)
self.logger.log_stat("whirl/epochs", epoch_num, self.num_env_steps_train)

if len(self.game_rewards) > 0:
mean_rewards = np.mean(self.game_rewards)
@@ -429,6 +492,14 @@ def train(self):
self.writer.add_scalar('win_rate/mean', mean_scores, frame)
self.writer.add_scalar('win_rate/time', mean_scores, total_time)

self.logger.log_stat("whirl/rewards/mean", np.asscalar(mean_rewards), self.num_env_steps_train)
self.logger.log_stat("whirl/rewards/time", mean_rewards, total_time)
self.logger.log_stat("whirl/episode_lengths/mean", np.asscalar(mean_lengths), self.num_env_steps_train)
self.logger.log_stat("whirl/episode_lengths/time", mean_lengths, total_time)
self.logger.log_stat("whirl/win_rate/mean", np.asscalar(mean_scores), self.num_env_steps_train)
self.logger.log_stat("whirl/win_rate/time", mean_scores, total_time)


if rep_count % 10 == 0:
self.save("./nn/" + 'last_' + self.config['name'] + 'ep=' + str(epoch_num) + 'rew=' + str(mean_rewards))
rep_count += 1
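Note on the rollout handling above: the new mb_central_states buffer is flattened with the same swap_and_flatten01 helper as the observation, value, and RNN-state buffers before minibatching. A minimal, self-contained sketch of that reshaping (the array shapes here are hypothetical, not taken from the PR):

import numpy as np

def swap_and_flatten01(arr):
    # Same helper as in algos_tf14/a2c_discrete.py:
    # (steps, actors, *feature_dims) -> (actors * steps, *feature_dims)
    s = arr.shape
    return arr.swapaxes(0, 1).reshape(s[0] * s[1], *s[2:])

# Hypothetical rollout: 4 steps, 2 parallel actors, 5-dim central state
mb_central_states = np.arange(4 * 2 * 5, dtype=np.float32).reshape(4, 2, 5)
flat = swap_and_flatten01(mb_central_states)
print(flat.shape)  # (8, 5): one row per (actor, step) pair, ready for minibatch indexing
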
20 changes: 17 additions & 3 deletions algos_tf14/dqnagent.py
@@ -11,7 +11,7 @@
from common.categorical import CategoricalQ

class DQNAgent:
def __init__(self, sess, base_name, observation_space, action_space, config):
def __init__(self, sess, base_name, observation_space, action_space, config, logger):
observation_shape = observation_space.shape
actions_num = action_space.n
self.config = config
@@ -47,7 +47,7 @@ def __init__(self, sess, base_name, observation_space, action_space, config):
self.epsilon_processor = tr_helpers.LinearValueProcessor(self.config['epsilon'], self.config['min_epsilon'], self.config['epsilon_decay_frames'])
self.beta_processor = tr_helpers.LinearValueProcessor(self.config['priority_beta'], self.config['max_beta'], self.config['beta_decay_frames'])
if self.env_name:
self.env = env_configurations.configurations[self.env_name]['env_creator']()
self.env = env_configurations.configurations[self.env_name]['env_creator'](name=config['name'])
self.sess = sess
self.steps_num = self.config['steps_num']
self.states = deque([], maxlen=self.steps_num)
@@ -402,7 +402,16 @@ def train(self):
self.writer.add_scalar('info/epsilon', self.epsilon, frame)
if self.is_prioritized:
self.writer.add_scalar('beta', self.beta, frame)


self.logger.log_stat("whirl/performance/fps", 1000 / sum_time, self.num_env_steps_train)
self.logger.log_stat("whirl/performance/upd_time", update_time, self.num_env_steps_train)
self.logger.log_stat("whirl/performance/play_time", play_time, self.num_env_steps_train)
self.logger.log_stat("losses/td_loss", np.mean(losses), self.num_env_steps_train)
self.logger.log_stat("whirl/info/last_lr", self.learning_rate*lr_mul, self.num_env_steps_train)
self.logger.log_stat("whirl/info/lr_mul", lr_mul, self.num_env_steps_train)
self.logger.log_stat("whirl/epochs", epoch_num, self.num_env_steps_train)
self.logger.log_stat("whirl/epsilon", self.epsilon, self.num_env_steps_train)

update_time = 0
play_time = 0
num_games = len(self.game_rewards)
@@ -415,6 +424,11 @@ def train(self):
self.writer.add_scalar('episode_lengths/mean', mean_lengths, frame)
self.writer.add_scalar('episode_lengths/time', mean_lengths, total_time)

self.logger.log_stat("whirl/rewards/mean", np.asscalar(mean_rewards), self.num_env_steps_train)
self.logger.log_stat("whirl/rewards/time", mean_rewards, total_time)
self.logger.log_stat("whirl/episode_lengths/mean", np.asscalar(mean_lengths), self.num_env_steps_train)
self.logger.log_stat("whirl/episode_lengths/time", mean_lengths, total_time)

if mean_rewards > last_mean_rewards:
print('saving next best rewards: ', mean_rewards)
last_mean_rewards = mean_rewards
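Both agents now take a logger argument and report metrics through logger.log_stat(key, value, t) alongside the existing TensorBoard writer. The concrete logger class is not part of this diff (the commit history points to a sacred-based setup), so the snippet below is only a minimal stand-in that satisfies the interface used here:

from collections import defaultdict

class StatLogger:
    # Minimal stand-in for the logger object passed to A2CAgent/DQNAgent.
    # Only the log_stat(key, value, t) interface used in this diff is assumed;
    # the real (sacred/WHIRL-based) logger in the repo may differ.
    def __init__(self):
        self.stats = defaultdict(list)

    def log_stat(self, key, value, t):
        self.stats[key].append((t, float(value)))
        print("t=%d %s=%.6f" % (t, key, float(value)))

# Example usage mirroring the calls added in train():
logger = StatLogger()
logger.log_stat("whirl/losses/td_loss", 0.123, 1000)
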
1 change: 1 addition & 0 deletions algos_tf14/model_builder.py
@@ -13,6 +13,7 @@ def __init__(self):
self.model_factory.register_builder('continuous_a2c_lstm', lambda network, **kwargs : models.LSTMModelA2CContinuous(network))
self.model_factory.register_builder('continuous_a2c_lstm_logstd', lambda network, **kwargs : models.LSTMModelA2CContinuousLogStd(network))
self.model_factory.register_builder('dqn', lambda network, **kwargs : models.AtariDQN(network))
self.model_factory.register_builder('vdn', lambda network, **kwargs : models.VDN_DQN(network))


self.network_factory = object_factory.ObjectFactory()
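The new 'vdn' entry registers models.VDN_DQN with the model factory. That class is not shown in this hunk; for context, VDN (Value-Decomposition Networks) trains per-agent Q-networks and represents the joint action value as their sum, Q_tot = sum_i Q_i(o_i, a_i). A purely illustrative sketch of that mixing step, not the PR's implementation:

import numpy as np

def vdn_mix(per_agent_chosen_qs):
    # per_agent_chosen_qs: (batch, n_agents) Q-values of each agent's chosen action.
    # VDN's additive factorization: Q_tot = sum over agents of Q_i.
    return np.sum(per_agent_chosen_qs, axis=1)

q_chosen = np.array([[1.0, 0.5, -0.2],
                     [0.3, 0.1, 0.4]], dtype=np.float32)
print(vdn_mix(q_chosen))  # [1.3 0.8]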