Skip to content

Commit

Permalink
docs for built-in callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Oct 27, 2015
1 parent 1b62ddf commit d19f28f
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 3 deletions.
65 changes: 65 additions & 0 deletions docs/api/callback.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,68 @@ Callbacks in training




.. function:: every_n_iter(callback :: Function, n :: Int; call_on_0 = false)

A convenient function to construct a callback that runs every ``n`` mini-batches.

:param Int call_on_0: keyword argument, default false. Unless set, the callback
will **not** be run on iteration 0.

For example, the :func:`speedometer` callback is defined as

.. code-block:: julia
every_n_iter(frequency, call_on_0=true) do param :: CallbackParams
if param.curr_iter == 0
# reset timer
else
# compute and print speed
end
end
:seealso: :func:`every_n_epoch`, :func:`speedometer`.




.. function:: speedometer(; frequency=50)

Create an :class:`AbstractIterationCallback` that measure the training speed
(number of samples processed per second) every k mini-batches.

:param Int frequency: keyword argument, default 50. The frequency (number of
min-batches) to measure and report the speed.




.. function:: every_n_epoch(callback :: Function, n :: Int; call_on_0 = false)

A convenient function to construct a callback that runs every ``n`` full data-passes.

:param Int call_on_0: keyword argument, default false. Unless set, the callback
will **not** be run on epoch 0. Epoch 0 means no training has been performed
yet. This is useful if you want to inspect the randomly initialized model
that has not seen any data yet.

:seealso: :func:`every_n_iter`.




.. function:: do_checkpoint(prefix; frequency=1, save_epoch_0=false)

Create an :class:`AbstractEpochCallback` that save checkpoints of the model to disk.
The checkpoints can be loaded back later on.

:param AbstractString prefix: the prefix of the filenames to save the model. The model
architecture will be saved to prefix-symbol.json, while the weights will be saved
to prefix-0012.params, for example, for the 12-th epoch.
:param Int frequency: keyword argument, default 1. The frequency (measured in epochs) to
save checkpoints.
:param Bool save_epoch_0: keyword argument, default false. Whether we should save a
checkpoint for epoch 0 (model initialized but not seen any data yet).



4 changes: 2 additions & 2 deletions docs/api/model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@ a network described using the symbolic API.
:param Int n_epoch: default 10, the number of full data-passes to run.
:param AbstractDataProvider eval_data: keyword argument, default ``nothing``. The data provider for
the validation set.
:param AbstractEvalMetric eval_metric: keyword argument, default :class:`Accuracy`. The metric used
:param AbstractEvalMetric eval_metric: keyword argument, default ``Accuracy()``. The metric used
to evaluate the training performance. If ``eval_data`` is provided, the same metric is also
calculated on the validation set.
:param kvstore: keyword argument, default ``:local``. The key-value store used to synchronize gradients
and parameters when multiple devices are used for training.
:type kvstore: :class:`KVStore` or ``Base.Symbol``
:param AbstractInitializer initializer: keyword argument, default :class:`UniformInitializer(0.01)`.
:param AbstractInitializer initializer: keyword argument, default ``UniformInitializer(0.01)``.
:param Bool force_init: keyword argument, default false. By default, the random initialization using the
provided ``initializer`` will be skipped if the model weights already exists, maybe from a previous
call to :func:`train` or an explicit call to :func:`init_model` or :func:`load_checkpoint`. When
Expand Down
59 changes: 58 additions & 1 deletion src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,28 @@ type IterationCallback <: AbstractIterationCallback
callback :: Function
end

#=doc
.. function:: every_n_iter(callback :: Function, n :: Int; call_on_0 = false)
A convenient function to construct a callback that runs every ``n`` mini-batches.
:param Int call_on_0: keyword argument, default false. Unless set, the callback
will **not** be run on iteration 0.
For example, the :func:`speedometer` callback is defined as
.. code-block:: julia
every_n_iter(frequency, call_on_0=true) do param :: CallbackParams
if param.curr_iter == 0
# reset timer
else
# compute and print speed
end
end
:seealso: :func:`every_n_epoch`, :func:`speedometer`.
=#
function every_n_iter(callback :: Function, n :: Int; call_on_0 :: Bool = false)
IterationCallback(n, call_on_0, callback)
end
Expand All @@ -50,11 +72,20 @@ function Base.call(cb :: IterationCallback, param :: CallbackParams)
end
end

#=doc
.. function:: speedometer(; frequency=50)
Create an :class:`AbstractIterationCallback` that measure the training speed
(number of samples processed per second) every k mini-batches.
:param Int frequency: keyword argument, default 50. The frequency (number of
min-batches) to measure and report the speed.
=#
function speedometer(;frequency::Int=50)
cl_tic = 0
every_n_iter(frequency, call_on_0=true) do param :: CallbackParams
if param.curr_iter == 0
# reset counter
# reset timer
cl_tic = time()
else
speed = frequency * param.batch_size / (time() - cl_tic)
Expand All @@ -71,6 +102,18 @@ type EpochCallback <: AbstractEpochCallback
callback :: Function
end

#=doc
.. function:: every_n_epoch(callback :: Function, n :: Int; call_on_0 = false)
A convenient function to construct a callback that runs every ``n`` full data-passes.
:param Int call_on_0: keyword argument, default false. Unless set, the callback
will **not** be run on epoch 0. Epoch 0 means no training has been performed
yet. This is useful if you want to inspect the randomly initialized model
that has not seen any data yet.
:seealso: :func:`every_n_iter`.
=#
function every_n_epoch(callback :: Function, n :: Int; call_on_0 :: Bool = false)
EpochCallback(n, call_on_0, callback)
end
Expand All @@ -84,6 +127,20 @@ function Base.call(cb :: EpochCallback, model :: Any, param :: CallbackParams)
end
end

#=doc
.. function:: do_checkpoint(prefix; frequency=1, save_epoch_0=false)
Create an :class:`AbstractEpochCallback` that save checkpoints of the model to disk.
The checkpoints can be loaded back later on.
:param AbstractString prefix: the prefix of the filenames to save the model. The model
architecture will be saved to prefix-symbol.json, while the weights will be saved
to prefix-0012.params, for example, for the 12-th epoch.
:param Int frequency: keyword argument, default 1. The frequency (measured in epochs) to
save checkpoints.
:param Bool save_epoch_0: keyword argument, default false. Whether we should save a
checkpoint for epoch 0 (model initialized but not seen any data yet).
=#
function do_checkpoint(prefix::AbstractString; frequency::Int=1, save_epoch_0=false)
mkpath(dirname(prefix))
every_n_epoch(frequency, call_on_0=save_epoch_0) do model, param
Expand Down

0 comments on commit d19f28f

Please sign in to comment.