Skip to content

Commit

Permalink
refine model.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin committed Nov 21, 2017
1 parent 78053d7 commit 9fc72da
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ mutable struct FeedForward <: AbstractModel
arch :: SymbolicNode
ctx :: Vector{Context}

arg_params :: Dict{Symbol,<:NDArray}
aux_params :: Dict{Symbol,<:NDArray}
arg_params :: Dict{Symbol}
aux_params :: Dict{Symbol}

pred_exec :: Union{Executor, Void}

# leave the rest fields undefined
FeedForward(arch :: SymbolicNode, ctx :: Vector{Context}) = new(arch, ctx)
FeedForward(arch::SymbolicNode, ctx::Vector{Context}) = new(arch, ctx)
end

"""
Expand Down Expand Up @@ -264,7 +264,7 @@ function _init_model(self :: FeedForward, data :: AbstractDataProvider, initiali
init_model(self, initializer; overwrite=overwrite, [provide_data(data)..., provide_label(data)...]...)
end

function _create_kvstore(kv_type :: Base.Symbol, num_device :: Int, arg_params :: Dict{Base.Symbol,NDArray}, verbosity :: Int)
function _create_kvstore(kv_type::Symbol, num_device::Int, arg_params::Dict{Symbol}, verbosity :: Int)
if num_device == 1 && !ismatch(r"dist", string(kv_type))
return nothing
else
Expand All @@ -286,15 +286,15 @@ end
n_epoch :: Int = 10,
eval_data :: Union{Void, AbstractDataProvider} = nothing,
eval_metric :: AbstractEvalMetric = Accuracy(),
kvstore :: Union{Base.Symbol, KVStore} = :local,
kvstore :: Union{Symbol, KVStore} = :local,
force_init :: Bool = false,
callbacks :: Vector{AbstractCallback} = AbstractCallback[],
verbosity :: Int = 3
)

function _invoke_callbacks(self::FeedForward, callbacks::Vector{AbstractCallback},
state::OptimizationState, type_filter::Type;
metric::Vector{Tuple{Base.Symbol, T}} = Vector{Tuple{Base.Symbol, Real}}()) where T<:Real
metric::Vector{Tuple{Symbol, T}} = Vector{Tuple{Symbol, Real}}()) where T<:Real
map(callbacks) do cb
if isa(cb, type_filter)
if type_filter == AbstractEpochCallback
Expand Down Expand Up @@ -332,7 +332,7 @@ Train the `model` on `data` with the `optimizer`.
calculated on the validation set.
* `kvstore`: keyword argument, default `:local`. The key-value store used to synchronize gradients
and parameters when multiple devices are used for training.
:type kvstore: `KVStore` or `Base.Symbol`
:type kvstore: `KVStore` or `Symbol`
* `initializer::AbstractInitializer`: keyword argument, default `UniformInitializer(0.01)`.
* `force_init::Bool`: keyword argument, default false. By default, the random initialization using the
provided `initializer` will be skipped if the model weights already exists, maybe from a previous
Expand Down Expand Up @@ -362,7 +362,7 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra

# setup kvstore
kvstore = opts.kvstore
if isa(kvstore, Base.Symbol)
if isa(kvstore, Symbol)
opts.verbosity >= 2 && info("Creating KVStore...")
kvstore = _create_kvstore(kvstore, length(self.ctx), self.arg_params, opts.verbosity)
end
Expand Down Expand Up @@ -622,7 +622,7 @@ end
Load a mx.FeedForward model from the checkpoint *prefix*, *epoch* and optionally provide a context.
"""
function load_checkpoint(prefix :: AbstractString, epoch :: Int, ::Type{FeedForward}; context = nothing)
function load_checkpoint(prefix::AbstractString, epoch::Int, ::Type{FeedForward}; context = nothing)
arch, arg_params, aux_params = load_checkpoint(prefix, epoch)
model = FeedForward(arch, context = context)
model.arg_params = arg_params
Expand Down

0 comments on commit 9fc72da

Please sign in to comment.