Fixed various issues with num_frames #213

Merged · 4 commits · Jan 8, 2023
Changes from all commits
6 changes: 4 additions & 2 deletions rl_games/algos_torch/a2c_discrete.py
@@ -10,9 +10,10 @@
import torch
from torch import nn
import numpy as np
import gym


class DiscreteA2CAgent(a2c_common.DiscreteA2CBase):

def __init__(self, base_name, params):
a2c_common.DiscreteA2CBase.__init__(self, base_name, params)
obs_shape = self.obs_shape
@@ -21,7 +22,7 @@ def __init__(self, base_name, params):
'actions_num' : self.actions_num,
'input_shape' : obs_shape,
'num_seqs' : self.num_actors * self.num_agents,
'value_size': self.env_info.get('value_size',1),
'value_size': self.env_info.get('value_size', 1),
'normalize_value': self.normalize_value,
'normalize_input': self.normalize_input,
}
@@ -48,6 +49,7 @@ def __init__(self, base_name, params):
'config' : self.central_value_config,
'writter' : self.writer,
'max_epochs' : self.max_epochs,
'max_frames' : self.max_frames,
'multi_gpu' : self.multi_gpu,
}
self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)
16 changes: 12 additions & 4 deletions rl_games/algos_torch/central_value.py
@@ -10,9 +10,14 @@
from rl_games.common import datasets
from rl_games.common import schedulers


class CentralValueTrain(nn.Module):
def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_length, num_actors, num_actions, seq_len, normalize_value,network, config, writter, max_epochs, multi_gpu):

def __init__(self, state_shape, value_size, ppo_device, num_agents, \
horizon_length, num_actors, num_actions, seq_len, \
normalize_value,network, config, writter, max_epochs, multi_gpu):
nn.Module.__init__(self)

self.ppo_device = ppo_device
self.num_agents, self.horizon_length, self.num_actors, self.seq_len = num_agents, horizon_length, num_actors, seq_len
self.normalize_value = normalize_value
@@ -37,16 +42,19 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng
self.model = network.build(state_config)
self.lr = float(config['learning_rate'])
self.linear_lr = config.get('lr_schedule') == 'linear'

# todo: support max frames as well
if self.linear_lr:
self.scheduler = schedulers.LinearScheduler(self.lr,
max_steps=self.max_epochs,
apply_to_entropy=False,
start_entropy_coef=0)
max_steps = self.max_epochs,
apply_to_entropy = False,
start_entropy_coef = 0)
else:
self.scheduler = schedulers.IdentityScheduler()

self.mini_epoch = config['mini_epochs']
assert(('minibatch_size_per_env' in self.config) or ('minibatch_size' in self.config))

self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0)
self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env)
self.num_minibatches = self.horizon_length * self.num_actors // self.minibatch_size
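Note: as a reading aid, here is a small worked example of the minibatch sizing in the hunk above. All values are hypothetical, not taken from the PR or its configs.

# Worked example of the minibatch bookkeeping above (values are made up).
horizon_length = 16            # rollout length per actor
num_actors = 1024              # parallel environments
minibatch_size_per_env = 4     # from config['minibatch_size_per_env']

# When only the per-env size is given, minibatch_size falls back to
# num_actors * minibatch_size_per_env, mirroring the config.get(...) default.
minibatch_size = num_actors * minibatch_size_per_env             # 4096
num_minibatches = horizon_length * num_actors // minibatch_size  # 16 * 1024 // 4096 = 4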
45 changes: 34 additions & 11 deletions rl_games/algos_torch/sac_agent.py
@@ -1,10 +1,9 @@
from rl_games.algos_torch import torch_ext

from rl_games.algos_torch.running_mean_std import RunningMeanStd

from rl_games.common import vecenv
from rl_games.common import schedulers
from rl_games.common import experience
from rl_games.common.a2c_common import print_statistics

from rl_games.interfaces.base_algorithm import BaseAlgorithm
from torch.utils.tensorboard import SummaryWriter
@@ -18,6 +17,7 @@
import time
import os


class SACAgent(BaseAlgorithm):

def __init__(self, base_name, params):
@@ -131,7 +131,8 @@ def base_init(self, base_name, config):
self.rnn_states = None
self.name = base_name

self.max_epochs = self.config.get('max_epochs', 1e6)
self.max_epochs = self.config.get('max_epochs', -1)
self.max_frames = self.config.get('max_frames', -1)

self.network = config['network']
self.rewards_shaper = config['reward_shaper']
@@ -150,6 +151,7 @@ def base_init(self, base_name, config):
self.last_mean_rewards = -100500
self.play_time = 0
self.epoch_num = 0

# TODO: put it into the separate class
pbt_str = ''
self.population_based_training = config.get('population_based_training', False)
@@ -522,8 +524,8 @@ def train(self):
fps_step_inference = curr_frames / play_time
fps_total = curr_frames / epoch_total_time

if self.print_stats:
print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {self.epoch_num}/{self.max_epochs}')
print_statistics(self.print_stats, curr_frames, step_time, play_time, epoch_total_time,
self.epoch_num, self.max_epochs, self.frame, self.max_frames)

self.writer.add_scalar('performance/step_inference_rl_update_fps', fps_total, self.frame)
self.writer.add_scalar('performance/step_inference_fps', fps_step_inference, self.frame)
@@ -554,18 +556,39 @@ def train(self):
self.writer.add_scalar('episode_lengths/step', mean_lengths, self.frame)
self.writer.add_scalar('episode_lengths/time', mean_lengths, total_time)
checkpoint_name = self.config['name'] + '_ep_' + str(self.epoch_num) + '_rew_' + str(mean_rewards)

should_exit = False

if mean_rewards > self.last_mean_rewards and self.epoch_num >= self.save_best_after:
print('saving next best rewards: ', mean_rewards)
self.last_mean_rewards = mean_rewards
self.save(os.path.join(self.nn_dir, self.config['name']))
if self.last_mean_rewards > self.config.get('score_to_win', float('inf')):
print('Network won!')
print('Maximum reward achieved. Network won!')
self.save(os.path.join(self.nn_dir, checkpoint_name))
return self.last_mean_rewards, self.epoch_num
should_exit = True

if self.epoch_num >= self.max_epochs:
self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(self.epoch_num) + 'rew' + str(mean_rewards)))
print('MAX EPOCHS NUM!')
return self.last_mean_rewards, self.epoch_num
if self.epoch_num >= self.max_epochs and self.max_epochs != -1:
if self.game_rewards.current_size == 0:
print('WARNING: Max epochs reached before any env terminated at least once')
mean_rewards = -np.inf

self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + '_ep_' + str(self.epoch_num) \
+ '_rew_' + str(mean_rewards).replace('[', '_').replace(']', '_')))
print('MAX EPOCHS NUM!')
should_exit = True

if self.frame >= self.max_frames and self.max_frames != -1:
if self.game_rewards.current_size == 0:
print('WARNING: Max frames reached before any env terminated at least once')
mean_rewards = -np.inf

self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + '_frame_' + str(self.frame) \
+ '_rew_' + str(mean_rewards).replace('[', '_').replace(']', '_')))
print('MAX FRAMES NUM!')
should_exit = True

update_time = 0

if should_exit:
return self.last_mean_rewards, self.epoch_num
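Note: the new stopping behaviour above treats -1 as "no limit" and exits once any budget that is set has been spent. Below is a minimal standalone restatement; the helper name and values are illustrative, not part of the PR.

def budget_exhausted(epoch_num, frame, max_epochs=-1, max_frames=-1):
    # -1 disables the corresponding limit, mirroring the checks in train() above.
    hit_epochs = max_epochs != -1 and epoch_num >= max_epochs
    hit_frames = max_frames != -1 and frame >= max_frames
    return hit_epochs or hit_frames

# With only a frame budget set, the epoch count never triggers an exit:
assert budget_exhausted(epoch_num=10_000, frame=5_000_000, max_frames=5_000_000)
assert not budget_exhausted(epoch_num=10_000, frame=4_999_999, max_frames=5_000_000)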
120 changes: 91 additions & 29 deletions rl_games/common/a2c_common.py
@@ -43,8 +43,27 @@ def rescale_actions(low, high, action):
return scaled_action


def print_statistics(print_stats, curr_frames, step_time, step_inference_time, total_time, epoch_num, max_epochs, frame, max_frames):
if print_stats:
step_time = max(step_time, 1e-9)
fps_step = curr_frames / step_time
fps_step_inference = curr_frames / step_inference_time
fps_total = curr_frames / total_time

if max_epochs == -1 and max_frames == -1:
print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num:.0f} frames: {frame:.0f}')
elif max_epochs == -1:
print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num:.0f} frames: {frame:.0f}/{max_frames:.0f}')
elif max_frames == -1:
print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num:.0f}/{max_epochs:.0f} frames: {frame:.0f}')
else:
print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num:.0f}/{max_epochs:.0f} frames: {frame:.0f}/{max_frames:.0f}')


class A2CBase(BaseAlgorithm):

def __init__(self, base_name, params):

self.config = config = params['config']
pbt_str = ''
self.population_based_training = config.get('population_based_training', False)
@@ -133,20 +152,36 @@ def __init__(self, base_name, params):
self.name = base_name

self.ppo = config.get('ppo', True)
self.max_epochs = self.config.get('max_epochs', 1e6)
self.max_frames = self.config.get('max_frames', 1e10)
self.max_epochs = self.config.get('max_epochs', -1)
self.max_frames = self.config.get('max_frames', -1)

self.is_adaptive_lr = config['lr_schedule'] == 'adaptive'
self.linear_lr = config['lr_schedule'] == 'linear'
self.schedule_type = config.get('schedule_type', 'legacy')

# Setting learning rate scheduler
if self.is_adaptive_lr:
self.kl_threshold = config['kl_threshold']
self.scheduler = schedulers.AdaptiveScheduler(self.kl_threshold)

elif self.linear_lr:
self.scheduler = schedulers.LinearScheduler(float(config['learning_rate']),
max_steps=self.max_epochs,
apply_to_entropy=config.get('schedule_entropy', False),
start_entropy_coef=config.get('entropy_coef'))

if self.max_epochs == -1 and self.max_frames == -1:
print("Max epochs and max frames are not set. Linear learning rate schedule can't be used, switching to the constant (identity) one.")
self.scheduler = schedulers.IdentityScheduler()
else:
use_epochs = True
max_steps = self.max_epochs

if self.max_epochs == -1:
use_epochs = False
max_steps = self.max_frames

self.scheduler = schedulers.LinearScheduler(float(config['learning_rate']),
max_steps = max_steps,
use_epochs = use_epochs,
apply_to_entropy = config.get('schedule_entropy', False),
start_entropy_coef = config.get('entropy_coef'))
else:
self.scheduler = schedulers.IdentityScheduler()
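Note: the scheduler selection above reduces to the following decision. This is a standalone sketch with an invented helper name, not the rl_games API.

def resolve_linear_schedule(max_epochs=-1, max_frames=-1):
    # Returns (use_identity, use_epochs, max_steps), mirroring the branching above.
    if max_epochs == -1 and max_frames == -1:
        return True, None, None          # nothing to anneal over: fall back to identity
    if max_epochs != -1:
        return False, True, max_epochs   # epochs take precedence when both are set
    return False, False, max_frames      # otherwise anneal the LR over frames

# e.g. only a frame budget is set: the linear schedule runs over 50M frames.
assert resolve_linear_schedule(-1, 50_000_000) == (False, False, 50_000_000)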

@@ -556,9 +591,10 @@ def get_full_state_weights(self):

def set_full_state_weights(self, weights):
self.set_weights(weights)
self.epoch_num = weights['epoch']
self.epoch_num = weights['epoch'] # frames as well?
if self.has_central_value:
self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])

self.optimizer.load_state_dict(weights['optimizer'])
self.frame = weights.get('frame', 0)
self.last_mean_rewards = weights.get('last_mean_rewards', -100500)
@@ -757,8 +793,10 @@ def play_steps_rnn(self):


class DiscreteA2CBase(A2CBase):

def __init__(self, base_name, params):
A2CBase.__init__(self, base_name, params)

batch_size = self.num_agents * self.num_actors
action_space = self.env_info['action_space']
if type(action_space) is gym.spaces.Discrete:
@@ -922,20 +960,18 @@ def train(self):
should_exit = False

if self.rank == 0:
self.diagnostics.epoch(self, current_epoch=epoch_num)
self.diagnostics.epoch(self, current_epoch = epoch_num)
scaled_time = self.num_agents * sum_time
scaled_play_time = self.num_agents * play_time

frame = self.frame // self.num_agents

if self.print_stats:
step_time = max(step_time, 1e-6)
fps_step = curr_frames / step_time
fps_step_inference = curr_frames / scaled_play_time
fps_total = curr_frames / scaled_time
print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs} frame: {self.frame}/{self.max_frames}')
print_statistics(self.print_stats, curr_frames, step_time, scaled_play_time, scaled_time,
epoch_num, self.max_epochs, frame, self.max_frames)

self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)
self.write_stats(total_time, epoch_num, step_time, play_time, update_time,
a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame,
scaled_time, scaled_play_time, curr_frames)

self.algo_observer.after_print_stats(frame, epoch_num, total_time)

@@ -971,31 +1007,46 @@ def train(self):

if 'score_to_win' in self.config:
if self.last_mean_rewards > self.config['score_to_win']:
print('Network won!')
print('Maximum reward achieved. Network won!')
self.save(os.path.join(self.nn_dir, checkpoint_name))
should_exit = True

if epoch_num >= self.max_epochs or self.frame >= self.max_frames:
if epoch_num >= self.max_epochs and self.max_epochs != -1:
if self.game_rewards.current_size == 0:
print('WARNING: Max epochs reached before any env terminated at least once')
mean_rewards = -np.inf

self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards)))
self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + '_ep_' + str(epoch_num) \
+ '_rew_' + str(mean_rewards).replace('[', '_').replace(']', '_')))
print('MAX EPOCHS NUM!')
should_exit = True

if self.frame >= self.max_frames and self.max_frames != -1:
if self.game_rewards.current_size == 0:
print('WARNING: Max frames reached before any env terminated at least once')
mean_rewards = -np.inf

self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + '_frame_' + str(self.frame) \
+ '_rew_' + str(mean_rewards).replace('[', '_').replace(']', '_')))
print('MAX FRAMES NUM!')
should_exit = True

update_time = 0

if self.multi_gpu:
should_exit_t = torch.tensor(should_exit, device=self.device).float()
dist.broadcast(should_exit_t, 0)
should_exit = should_exit_t.bool().item()

if should_exit:
return self.last_mean_rewards, epoch_num


class ContinuousA2CBase(A2CBase):

def __init__(self, base_name, params):
A2CBase.__init__(self, base_name, params)

self.is_discrete = False
action_space = self.env_info['action_space']
self.actions_num = action_space.shape[0]
@@ -1178,21 +1229,20 @@ def train(self):
should_exit = False

if self.rank == 0:
self.diagnostics.epoch(self, current_epoch=epoch_num)
self.diagnostics.epoch(self, current_epoch = epoch_num)
# do we need scaled_time?
scaled_time = self.num_agents * sum_time
scaled_play_time = self.num_agents * play_time
curr_frames = self.curr_frames * self.rank_size if self.multi_gpu else self.curr_frames
self.frame += curr_frames

if self.print_stats:
step_time = max(step_time, 1e-6)
fps_step = curr_frames / step_time
fps_step_inference = curr_frames / scaled_play_time
fps_total = curr_frames / scaled_time
print(f'fps step: {fps_step:.0f} fps step and policy inference: {fps_step_inference:.0f} fps total: {fps_total:.0f} epoch: {epoch_num}/{self.max_epochs} frame: {self.frame}/{self.max_frames}')
print_statistics(self.print_stats, curr_frames, step_time, scaled_play_time, scaled_time,
epoch_num, self.max_epochs, frame, self.max_frames)

self.write_stats(total_time, epoch_num, step_time, play_time, update_time,
a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame,
scaled_time, scaled_play_time, curr_frames)

self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)
if len(b_losses) > 0:
self.writer.add_scalar('losses/bounds_loss', torch_ext.mean_list(b_losses).item(), frame)

@@ -1230,18 +1280,30 @@ def train(self):

if 'score_to_win' in self.config:
if self.last_mean_rewards > self.config['score_to_win']:
print('Network won!')
print('Maximum reward achieved. Network won!')
self.save(os.path.join(self.nn_dir, checkpoint_name))
should_exit = True

if epoch_num >= self.max_epochs or self.frame >= self.max_frames:
if epoch_num >= self.max_epochs and self.max_epochs != -1:
if self.game_rewards.current_size == 0:
print('WARNING: Max epochs reached before any env terminated at least once')
mean_rewards = -np.inf
self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + 'ep' + str(epoch_num) + 'rew' + str(mean_rewards).replace('[', '_').replace(']', '_')))

self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + '_ep_' + str(epoch_num) \
+ '_rew_' + str(mean_rewards).replace('[', '_').replace(']', '_')))
print('MAX EPOCHS NUM!')
should_exit = True

if self.frame >= self.max_frames and self.max_frames != -1:
if self.game_rewards.current_size == 0:
print('WARNING: Max frames reached before any env terminated at least once')
mean_rewards = -np.inf

self.save(os.path.join(self.nn_dir, 'last_' + self.config['name'] + '_frame_' + str(self.frame) \
+ '_rew_' + str(mean_rewards).replace('[', '_').replace(']', '_')))
print('MAX FRAMES NUM!')
should_exit = True

update_time = 0

if self.multi_gpu:
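Note: a hedged usage sketch of the knobs touched by this PR. The keys below are the ones read in the diffs above (max_epochs, max_frames, lr_schedule, learning_rate, name); the concrete values are placeholders, not recommendations from the PR.

config = {
    'name': 'example_run',          # placeholder run name
    'max_epochs': -1,               # -1: no epoch limit
    'max_frames': 100_000_000,      # stop after 100M environment frames
    'lr_schedule': 'linear',        # with max_epochs == -1, the linear schedule anneals over frames
    'learning_rate': 3e-4,          # placeholder
}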