Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions docs/sources/source/tutorials/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,16 @@ tb_writer should also be defined.
def my_get_tb_values(tensors):
return [("Train_Loss", tensors[0])]

log_to_tb_func() takes two arguments: the
log_to_tb_func() takes three arguments: the
`tensorboardX.SummaryWriter <https://tensorboardx.readthedocs.io/en/latest/tensorboard.html>`_
and a list
of evaluated tensors. The user can then use the SummaryWriter class to add
images, audio, and more. For example:
, a list of evaluated tensors, and the current step. The user can then use the
SummaryWriter class to add images, audio, and more. For example:

.. code-block:: python

def log_to_tb_func(swriter, tensors):
swriter.add_scalar("Train_Loss", tensors[0])
swriter.add_audio("Train_Sample", tensors[1][0])
def log_to_tb_func(swriter, tensors, step):
swriter.add_scalar("Train_Loss", tensors[0], step)
swriter.add_audio("Train_Sample", tensors[1][0], step)

SimpleLossLoggerCallback can be constructed as follows:

Expand Down Expand Up @@ -133,9 +132,21 @@ tensors from values_dict to global_var_dict as global_var_dict is saved
between batches and passed to the final user_epochs_done_callback function.

user_epochs_done_callback is a function that accepts global_var_dict. Its job
is to log relevant information to the screen such as the evaluation loss. It
is to log relevant information to the screen such as the evaluation loss.

For simple logging of scalar values to tensorboard, user_epochs_done_callback
should return a dictionary with strings as keys and scalar tensors as values.
This tag -> value dictionary will be parsed and each element will be logged
to tensorboard if a tensorboard writer object is declared.

To enable more complex tensorboard logging such as images or audio,
EvaluatorCallback must be passed tb_writer_func at initialization. This
function must accept a
`tensorboardX.SummaryWriter <https://tensorboardx.readthedocs.io/en/latest/tensorboard.html>`_
, whatever is returned from user_epochs_done_callback, and the current step.
We recommend that user_epochs_done_callback simply return the global_var_dict
for tb_writer_func to consume. The user must log all data of interest inside
tb_writer_func, including scalars that would otherwise be logged if
tb_writer_func were not passed to EvaluatorCallback.

For an example, please see the scripts inside <nemo_dir>/examples.
10 changes: 7 additions & 3 deletions nemo/nemo/backends/pytorch/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,9 +659,13 @@ def _eval(self, tensors_2_evaluate, callback, step, verbose=False):
vals_to_log = callback.user_done_callback(
callback._global_var_dict)
# log results to Tensorboard
if vals_to_log is not None and callback._swriter is not None:
for key, val in vals_to_log.items():
callback._swriter.add_scalar(key, val, step)
if vals_to_log is not None and callback.swriter is not None:
if callback.tb_writer_func is not None:
callback.tb_writer_func(
callback.swriter, vals_to_log, step)
else:
for key, val in vals_to_log.items():
callback.swriter.add_scalar(key, val, step)

def _infer(self, tensors_to_return, step, verbose=False):
"""
Expand Down
13 changes: 12 additions & 1 deletion nemo/nemo/core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ def on_iteration_end(self):
value = value.item()
self._swriter.add_scalar(name, value, step)
if self._log_to_tb_func:
self._log_to_tb_func(self._swriter, tensor_values)
self._log_to_tb_func(
self._swriter, tensor_values, step)
run_time = time.time() - self._last_iter_start
self._swriter.add_scalar('misc/step_time', run_time, step)
run_time = time.time() - self._last_iter_start
Expand Down Expand Up @@ -352,6 +353,7 @@ def __init__(
user_iter_callback,
user_epochs_done_callback,
tb_writer=None,
tb_writer_func=None,
eval_step=1,
eval_epoch=None,
):
Expand All @@ -369,6 +371,7 @@ def __init__(
super().__init__()
self._eval_tensors = eval_tensors
self._swriter = tb_writer
self._tb_writer_func = tb_writer_func
self._eval_frequency = eval_step
# will be passed to callbacks below
self._global_var_dict = {}
Expand All @@ -381,6 +384,14 @@ def __init__(
def eval_tensors(self):
return self._eval_tensors

@property
def tb_writer_func(self):
return self._tb_writer_func

@property
def swriter(self):
return self._swriter

def on_epoch_end(self):
pass

Expand Down