Skip to content

Commit

Permalink
use a more sensible default argument in favor of apache#24
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Nov 13, 2015
1 parent 8829912 commit 5da45f3
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ end
:param Bool overwrite: an :class:`Executor` is initialized the first time predict is called. The memory
allocation of the :class:`Executor` depends on the mini-batch size of the test
data provider. If you call predict twice with data provider of the same batch-size,
then the executor can be re-used. Otherwise, if ``overwrite`` is false (default),
an error will be raised; if ``overwrite`` is set to true, a new :class:`Executor`
will be created to replace the old one.
then the executor can be potentially be re-used. So, if ``overwrite`` is false,
we will try to re-use, and raise an error if batch-size changed. If ``overwrite``
is true (the default), a new :class:`Executor` will be created to replace the old one.
.. note::
Expand All @@ -172,12 +172,19 @@ end
For the same reason, currently prediction will only use the first device even if multiple devices are
provided to construct the model.
.. note::
If you perform further after prediction. The weights are not automatically synchronized if ``overwrite``
is set to false and the old predictor is re-used. In this case
setting ``overwrite`` to true (the default) will re-initialize the predictor the next time you call
predict and synchronize the weights again.
:seealso: :func:`train`, :func:`fit`, :func:`init_model`, :func:`load_checkpoint`
=#
function predict(callback :: Function, self :: FeedForward, data :: AbstractDataProvider; overwrite :: Bool = false)
function predict(callback :: Function, self :: FeedForward, data :: AbstractDataProvider; overwrite :: Bool = true)
predict(self, data; overwrite = overwrite, callback=callback)
end
function predict(self :: FeedForward, data :: AbstractDataProvider; overwrite::Bool=false, callback::Union{Function,Void}=nothing)
function predict(self :: FeedForward, data :: AbstractDataProvider; overwrite::Bool=true, callback::Union{Function,Void}=nothing)
data_shapes = provide_data(data)
data_names = [x[1] for x in data_shapes]
_setup_predictor(self, overwrite; data_shapes...)
Expand Down

0 comments on commit 5da45f3

Please sign in to comment.