Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 46 additions & 27 deletions python/ray/rllib/optimizers/async_samples_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,14 @@ def _init(self,
num_sgd_iter=1,
minibatch_buffer_size=1,
_fake_gpus=False):
self.learning_started = False
self.train_batch_size = train_batch_size
self.sample_batch_size = sample_batch_size
self.broadcast_interval = broadcast_interval

self._stats_start_time = time.time()
self._last_stats_time = {}
self._last_stats_sum = {}

if num_gpus > 1 or num_data_loader_buffers > 1:
logger.info(
"Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format(
Expand Down Expand Up @@ -82,10 +85,12 @@ def _init(self,
assert len(self.remote_evaluators) > 0

# Stats
self.timers = {k: TimerStat() for k in ["train", "sample"]}
self._optimizer_step_timer = TimerStat()
self.num_weight_syncs = 0
self.num_replayed = 0
self.learning_started = False
self._stats_start_time = time.time()
self._last_stats_time = {}
self._last_stats_val = {}

# Kick off async background sampling
self.sample_tasks = TaskPool()
Expand All @@ -108,19 +113,36 @@ def _init(self,
self.replay_buffer_num_slots = replay_buffer_num_slots
self.replay_batches = []

def add_stat_val(self, key, val):
    """Add `val` to the running total stored for stat `key`.

    The first time a key is seen its total starts at 0 and its window
    start is set to `self._stats_start_time`; later calls only add to
    the total and leave the window start untouched.

    Args:
        key: name of the stat to accumulate into.
        val: amount to add to the running total.
    """
    totals = self._last_stats_sum
    first_observation = key not in totals
    if first_observation:
        # Seed the accumulator, then anchor the averaging window at the
        # shared stats start time rather than at "now".
        totals[key] = 0
        self._last_stats_time[key] = self._stats_start_time
    totals[key] += val

def get_mean_stats_and_reset(self):
    """Return per-second means of the accumulated stats and reset them.

    For each key in `self._last_stats_sum`, the accumulated total is
    divided by the seconds elapsed since `self._last_stats_time[key]`
    and rounded to 3 decimals. Each total is then zeroed and its window
    start is moved to the same timestamp that was used to compute the
    means, so consecutive windows are contiguous and all keys share one
    window start.

    Returns:
        dict: stat key -> rounded rate. A key whose window has a
            non-positive duration reports 0.0.
    """
    now = time.time()
    mean_stats = {}
    for key, total in self._last_stats_sum.items():
        elapsed = now - self._last_stats_time[key]
        # time.time() may return the same value twice on a coarse clock
        # and can step backwards; report 0.0 instead of raising
        # ZeroDivisionError or emitting a negative rate.
        mean_stats[key] = round(total / elapsed, 3) if elapsed > 0 else 0.0

    # Reusing `now` (rather than calling time.time() again per key) keeps
    # the time spent in this method inside the next window.
    for key in self._last_stats_sum:
        self._last_stats_sum[key] = 0
        self._last_stats_time[key] = now

    return mean_stats

@override(PolicyOptimizer)
def step(self):
# NOTE(review): this span comes from a rendered diff whose +/- markers and
# indentation were lost, so lines from the old and the new body of step()
# appear interleaved. self._step() is called twice below for that reason --
# presumably only one of the two timing variants exists in the real file;
# confirm against the repository before editing.
#
# Visible behaviour: asserts the learner is alive, runs self._step() to get
# (sample_timesteps, train_timesteps), records throughput/timing stats, and
# adds both counts to num_steps_sampled / num_steps_trained.
assert self.learner.is_alive()
# Variant A (presumably the removed lines): wall-clock the call by hand and
# push the duration and sample count into self.timers["sample"].
start = time.time()
sample_timesteps, train_timesteps = self._step()
time_delta = time.time() - start
self.timers["sample"].push(time_delta)
self.timers["sample"].push_units_processed(sample_timesteps)
# Variant B (presumably the added lines): run the same call inside the
# self._optimizer_step_timer context manager.
with self._optimizer_step_timer:
sample_timesteps, train_timesteps = self._step()

# Only non-zero sample counts are accumulated under "sample_throughput".
if sample_timesteps > 0:
self.add_stat_val("sample_throughput", sample_timesteps)
# learning_started is set on the first step that trains on any timesteps
# and is never cleared in the lines shown here.
if train_timesteps > 0:
self.learning_started = True
if self.learning_started:
# The two self.timers["train"] pushes belong with Variant A (they read
# time_delta); the add_stat_val call is the counterpart that feeds the
# "train_throughput" accumulator.
self.timers["train"].push(time_delta)
self.timers["train"].push_units_processed(train_timesteps)
self.add_stat_val("train_throughput", train_timesteps)

self.num_steps_sampled += sample_timesteps
self.num_steps_trained += train_timesteps

Expand All @@ -130,27 +152,24 @@ def stop(self):

@override(PolicyOptimizer)
def stats(self):
# NOTE(review): this span comes from a rendered diff whose +/- markers and
# indentation were lost; old and new lines of stats() are interleaved and a
# block of review-comment page text sits in the middle. Treat the variant
# labels below as assumptions to confirm against the repository.
#
# Visible behaviour: builds a "timing_breakdown" dict of rounded timer
# means, combines it with sync/replay counters, the learner queue stats and
# the output of get_mean_stats_and_reset(), optionally adds the learner's
# own stats, and returns everything merged over PolicyOptimizer.stats(self).

# Helper: timer.mean multiplied by 1000 and rounded to 3 decimals
# (presumably seconds -> milliseconds, going by the *_time_ms keys).
def timer_to_ms(timer):
return round(1000 * timer.mean, 3)

timing = {
# Presumably the removed form: one "<name>_time_ms" entry per self.timers key.
"{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
for k in self.timers
# Presumably the added form: explicit entries built with timer_to_ms.
"optimizer_step_time_ms": timer_to_ms(self._optimizer_step_timer),
"learner_grad_time_ms": timer_to_ms(self.learner.grad_timer),
"learner_load_time_ms": timer_to_ms(self.learner.load_timer),
"learner_load_wait_time_ms": timer_to_ms(
self.learner.load_wait_timer),
"learner_dequeue_time_ms": timer_to_ms(self.learner.queue_timer),
}
# The next four assignments compute the same four learner_* keys as the
# explicit entries above, without the helper -- presumably the removed lines.
timing["learner_grad_time_ms"] = round(
1000 * self.learner.grad_timer.mean, 3)
timing["learner_load_time_ms"] = round(
1000 * self.learner.load_timer.mean, 3)
timing["learner_load_wait_time_ms"] = round(
1000 * self.learner.load_wait_timer.mean, 3)
timing["learner_dequeue_time_ms"] = round(
1000 * self.learner.queue_timer.mean, 3)
# Two openers for `stats` follow: a plain dict literal that reads throughput
# from self.timers, and a dict(...) call that instead merges in the result of
# get_mean_stats_and_reset(). Presumably old and new form respectively.
stats = {
"sample_throughput": round(self.timers["sample"].mean_throughput,
3),
"train_throughput": round(self.timers["train"].mean_throughput, 3),
stats = dict({
"num_weight_syncs": self.num_weight_syncs,
"num_steps_replayed": self.num_replayed,
"timing_breakdown": timing,
"learner_queue": self.learner.learner_queue_size.stats(),
}
# Calling get_mean_stats_and_reset() here also zeroes the accumulators, so
# every stats() call starts a new averaging window.
}, **self.get_mean_stats_and_reset())
# The lines up to the next blank-separated code line are page text from an
# inline review comment, not program code. Its remark presumably refers to
# the dict({...}, **extra) merge form used just above being required for
# Python 2 compatibility.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This syntax is needed for python 2.

# NOTE(review): _last_stats_val is cleared here, but the accumulators written
# by add_stat_val are _last_stats_sum / _last_stats_time; no visible line
# ever stores into _last_stats_val -- confirm whether this call is needed.
self._last_stats_val.clear()
# Learner stats are attached only when self.learner.stats is truthy.
if self.learner.stats:
stats["learner"] = self.learner.stats
# Keys in `stats` override same-named keys from the base-class stats.
return dict(PolicyOptimizer.stats(self), **stats)
Expand Down