parametric NDArray #331

Merged Dec 1, 2017. 40 commits from ib/param-nd into master; changes from all commits are shown below.
Commits (40, all by iblislin):

b3bf1fc  wip (Nov 19, 2017)
e0f73bc  ndarray: add outer constrcutor for AbstractArray (Nov 19, 2017)
e43b622  ndarray: refine copy (Nov 19, 2017)
2640ccc  ndarray: refine copy! (Nov 19, 2017)
7ff7a38  ndarray: refine convert (Nov 19, 2017)
886b3ad  ndarray: refine add_to! (Nov 19, 2017)
a7e4d46  ndarray: refine sub_from! (Nov 19, 2017)
3929a59  ndarray: refine mul_to! (Nov 19, 2017)
ca1d69b  ndarray: refine div_from! (Nov 19, 2017)
2df8321  ndarray: refine rdiv_from! (Nov 19, 2017)
8256b26  ndarray: refine _wait_to_read/_wait_to_write (Nov 19, 2017)
3c68e9a  ndarray: refine is_shared (Nov 19, 2017)
bd700a5  ndarray: refine save (Nov 19, 2017)
a648fa3  ndarray: refine dot (Nov 19, 2017)
0d7faf1  ndarray: VecOfNDArray (Nov 19, 2017)
bf21a7f  executor: refine backward (Nov 19, 2017)
6bf408d  ndarray: refine empty (Nov 19, 2017)
a53ea28  executor: refine bind (Nov 20, 2017)
42dbee4  refine callback (Nov 21, 2017)
fdeabba  refine executor (Nov 21, 2017)
8bde97e  refine kvstore (Nov 21, 2017)
658caeb  refine model.jl (Nov 21, 2017)
6ccb713  fix mnist-test (Nov 21, 2017)
12b36a0  metrics (Nov 24, 2017)
4956d0c  io (Nov 24, 2017)
16bd7f2  model (Nov 24, 2017)
c07a6a1  typo (Nov 24, 2017)
4177125  executor (Nov 24, 2017)
4bbef45  io (Nov 24, 2017)
f867d92  MSE (Nov 25, 2017)
dce2a4e  style (Nov 25, 2017)
0fce4f6  refine copy_params_from (Nov 27, 2017)
2777e46  Merge branch 'master' into ib/param-nd (Nov 27, 2017)
7daa617  Merge branch 'master' into ib/param-nd (Nov 27, 2017)
e5417e0  Merge branch 'master' into ib/param-nd (Nov 28, 2017)
8a43810  io: style stuff (Nov 29, 2017)
cbd47b6  ndarray: fix _remap (Nov 29, 2017)
fe8c251  io: style stuff (Nov 29, 2017)
f98a132  kvstore (Nov 29, 2017)
1571255  model: style stuff (Nov 29, 2017)
src/callback.jl (17 additions, 14 deletions)

@@ -48,7 +48,7 @@ end

 See also [`every_n_epoch`](@ref) and [`speedometer`](@ref).
 """
-function every_n_batch(callback :: Function, n :: Int; call_on_0 :: Bool = false)
+function every_n_batch(callback::Function, n::Int; call_on_0::Bool = false)
   BatchCallback(n, call_on_0, callback)
 end
 function (cb :: BatchCallback)(state :: OptimizationState)
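Usage sketch (not part of the diff): the do-block form passes the callback as the first positional argument of the tightened signature above; only that signature and the `curr_batch` field used by `speedometer` below are assumed.

    # Print progress every 10 mini-batches (and once at batch 0).
    cb = every_n_batch(10, call_on_0 = true) do state
      println("processed $(state.curr_batch) batches")
    end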
@@ -62,7 +62,7 @@ function (cb :: BatchCallback)(state :: OptimizationState)
 end

 """
-    speedometer(; frequency=50)
+    speedometer(;frequency=50)

 Create an `AbstractBatchCallback` that measure the training speed
 (number of samples processed per second) every k mini-batches.
@@ -71,9 +71,9 @@ Create an `AbstractBatchCallback` that measure the training speed
 * `frequency::Int`: keyword argument, default 50. The frequency (number of
   min-batches) to measure and report the speed.
 """
-function speedometer(;frequency::Int=50)
+function speedometer(;frequency::Int = 50)
   cl_tic = 0
-  every_n_batch(frequency, call_on_0=true) do state :: OptimizationState
+  every_n_batch(frequency, call_on_0 = true) do state::OptimizationState
     if state.curr_batch == 0
       # reset timer
       cl_tic = time()
@@ -104,10 +104,11 @@ A convenient function to construct a callback that runs every `n` full data-pass

 See also [`every_n_batch`](@ref).
 """
-function every_n_epoch(callback :: Function, n :: Int; call_on_0 :: Bool = false)
+every_n_epoch(callback::Function, n::Int; call_on_0::Bool = false) =
   EpochCallback(n, call_on_0, callback)
-end
-function (cb :: EpochCallback)(model :: Any, state :: OptimizationState, metric :: Vector{Tuple{Base.Symbol, T}}) where T<:Real
+
+function (cb::EpochCallback)(model::Any, state::OptimizationState,
+                             metric::Vector{Tuple{Symbol, T}}) where T<:Real
   if state.curr_epoch == 0
     if cb.call_on_0
       cb.callback(model, state, metric)
@@ -124,15 +125,17 @@ Create an `AbstractEpochCallback` that save checkpoints of the model to disk.
 The checkpoints can be loaded back later on.

 # Arguments
-* `prefix::AbstractString`: the prefix of the filenames to save the model. The model
-  architecture will be saved to prefix-symbol.json, while the weights will be saved
-  to prefix-0012.params, for example, for the 12-th epoch.
-* `frequency::Int`: keyword argument, default 1. The frequency (measured in epochs) to
-  save checkpoints.
+* `prefix::AbstractString`: the prefix of the filenames to save the model.
+  The model architecture will be saved to prefix-symbol.json,
+  while the weights will be saved to prefix-0012.params,
+  for example, for the 12-th epoch.
+* `frequency::Int`: keyword argument, default is 1.
+  The frequency (measured in epochs) to save checkpoints.
 * `save_epoch_0::Bool`: keyword argument, default false. Whether we should save a
-checkpoint for epoch 0 (model initialized but not seen any data yet).
+  checkpoint for epoch 0 (model initialized but not seen any data yet).
 """
-function do_checkpoint(prefix::AbstractString; frequency::Int=1, save_epoch_0=false)
+function do_checkpoint(prefix::AbstractString;
+                       frequency::Int = 1, save_epoch_0::Bool = false)
   mkpath(dirname(prefix))
   every_n_epoch(frequency, call_on_0=save_epoch_0) do model, state, metric
     save_checkpoint(model, prefix, state)
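Usage sketch for the checkpoint callback (not part of the diff; the `callbacks` keyword of `fit` is assumed from MXNet.jl's examples):

    # Writes models/mlp-symbol.json once, and models/mlp-0005.params,
    # models/mlp-0010.params, ... every 5 epochs.
    ckpt = do_checkpoint("models/mlp", frequency = 5)
    # fit(model, optimizer, train_provider, n_epoch = 20,
    #     callbacks = [ckpt, speedometer(frequency = 50)])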
src/executor.jl (52 additions, 61 deletions)

@@ -8,46 +8,49 @@ be carried out with an executor.
 mutable struct Executor
   handle :: MX_ExecutorHandle
   symbol :: SymbolicNode
-  arg_arrays :: Vector{NDArray}
-  grad_arrays :: Vector{Union{Void,NDArray}}
-  aux_arrays :: Vector{NDArray}
-  outputs :: Vector{NDArray}
-  arg_dict :: Dict{Base.Symbol, NDArray}
-  aux_dict :: Dict{Base.Symbol, NDArray}
+  arg_arrays :: VecOfNDArray
+  grad_arrays :: Vector{Union{Void,<:NDArray}}
+  aux_arrays :: VecOfNDArray
+  outputs :: VecOfNDArray
+  arg_dict :: Dict{Symbol}
+  aux_dict :: Dict{Symbol}
 end
-function Executor(hdr :: MX_ExecutorHandle, symbol :: SymbolicNode,
-                  arg_arrays :: Vector{NDArray}, grad_arrays :: Vector{Union{Void,NDArray}},
-                  aux_arrays :: Vector{NDArray})
+
+function Executor(hdl::MX_ExecutorHandle, sym::SymbolicNode,
+                  arg_arrays::VecOfNDArray, grad_arrays::AbstractVector,
+                  aux_arrays::VecOfNDArray)
   # get output arrays
   ref_size = Ref{MX_uint}(0)
-  ref_hdrs = Ref{Ptr{MX_handle}}(0)
+  ref_hdls = Ref{Ptr{MX_handle}}(C_NULL)
   @mxcall(:MXExecutorOutputs, (MX_handle, Ref{MX_uint}, Ref{Ptr{MX_handle}}),
-          hdr, ref_size, ref_hdrs)
-  out_hdrs = unsafe_wrap(Array, ref_hdrs[], ref_size[])
+          hdl, ref_size, ref_hdls)
+  out_hdrs = unsafe_wrap(Array, ref_hdls[], ref_size[])
   out_arrays = [NDArray(MX_NDArrayHandle(x)) for x in out_hdrs]

-  arg_names = list_arguments(symbol)
+  arg_names = list_arguments(sym)
   @assert(length(arg_names) == length(unique(arg_names)), "Duplicated names in arguments: $arg_names")
-  arg_dict = Dict{Base.Symbol,NDArray}(zip(arg_names, arg_arrays))
+  arg_dict = Dict(zip(arg_names, arg_arrays))

-  aux_names = list_auxiliary_states(symbol)
+  aux_names = list_auxiliary_states(sym)
   @assert(length(aux_names) == length(unique(aux_names)), "Duplicated names in auxiliary states: $aux_names")
-  aux_dict = Dict{Base.Symbol,NDArray}(zip(aux_names, aux_arrays))
+  aux_dict = Dict(zip(aux_names, aux_arrays))

-  Executor(hdr, symbol, arg_arrays, grad_arrays, aux_arrays, out_arrays, arg_dict, aux_dict)
+  Executor(hdl, sym, arg_arrays, grad_arrays, aux_arrays, out_arrays, arg_dict, aux_dict)
 end

-function Base.unsafe_convert(::Type{MX_handle}, obj::Executor)
+Base.unsafe_convert(::Type{MX_handle}, obj::Executor) =
   Base.unsafe_convert(MX_handle, obj.handle)
-end
 Base.convert(t::Type{MX_handle}, obj::Executor) = Base.unsafe_convert(t, obj)
 Base.cconvert(t::Type{MX_handle}, obj::Executor) = Base.unsafe_convert(t, obj)

-function _get_ndarray_inputs(arg_key::AbstractString, args::Vector{NDArray}, arg_names::Vector{Base.Symbol}, allow_missing::Bool)
+function _get_ndarray_inputs(arg_key::AbstractString, args::VecOfNDArray,
+                             arg_names::Vector{Symbol}, allow_missing::Bool)
   @assert(length(args) == length(arg_names), "Length of $arg_key does not match number of arguments")
   return (MX_handle[args...], args)
 end
-function _get_ndarray_inputs(arg_key::AbstractString, args::Dict{Base.Symbol,NDArray}, arg_names::Vector{Base.Symbol}, allow_missing::Bool)
+
+function _get_ndarray_inputs(arg_key::AbstractString, args::Dict{Symbol},
+                             arg_names::Vector{Symbol}, allow_missing::Bool)
   args_vec = map(arg_names) do name
     arr = get(args, name, nothing)
     if !allow_missing
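The `VecOfNDArray` alias used throughout this diff is introduced by commit 0d7faf1 ("ndarray: VecOfNDArray"); its definition is not shown here. A plausible sketch, consistent with how it is used above and with the parametric `NDArray{T,N}` this PR introduces:

    # Assumed definition: any vector whose elements are some NDArray{T,N}.
    const VecOfNDArray = AbstractVector{<:NDArray}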
@@ -75,16 +78,16 @@ Create an `Executor` by binding a `SymbolicNode` to concrete `NDArray`.
 * `ctx::Context`: the context on which the computation should run.
 * `args`: either a list of `NDArray` or a dictionary of name-array pairs. Concrete
   arrays for all the inputs in the network architecture. The inputs typically include
-  network parameters (weights, bias, filters, etc.), data and labels. See [`list_arguments`](@ref)
-  and [`infer_shape`](@ref).
-* `args_grad`:
-* `aux_states`:
-* `grad_req`:
+  network parameters (weights, bias, filters, etc.), data and labels.
+  See [`list_arguments`](@ref) and [`infer_shape`](@ref).
+* `args_grad`: a `Vector` of `NDArray` or a `Dict` contains `NDArray`
+* `aux_states`: a `Vector` of `NDArray` or a `Dict` contains `NDArray`
+* `grad_req`: single value, a `Vector` of `GRAD_REQ` or a `Dict{Symbol,GRAD_REQ}`
 """
-function bind(self :: SymbolicNode, ctx :: Context, args :: Union{Vector{NDArray},Dict{Base.Symbol,NDArray}};
-              args_grad :: Union{Vector{NDArray},Dict{Base.Symbol,NDArray}} = Dict{Base.Symbol,NDArray}(),
-              aux_states :: Union{Vector{NDArray},Dict{Base.Symbol,NDArray}} = Dict{Base.Symbol,NDArray}(),
-              grad_req :: Union{GRAD_REQ,Vector{GRAD_REQ},Dict{Base.Symbol,GRAD_REQ}} = GRAD_WRITE)
+function bind(self::SymbolicNode, ctx::Context, args;
+              args_grad = Dict{Symbol,NDArray}(),
+              aux_states = Dict{Symbol,NDArray}(),
+              grad_req = GRAD_WRITE)

   arg_names = list_arguments(self)

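A hedged sketch of calling `bind`; `Variable`, `FullyConnected`, and the column-major shapes are assumed from MXNet.jl's public API and are not part of this diff:

    x   = mx.Variable(:x)
    net = mx.FullyConnected(data = x, name = :fc, num_hidden = 2)
    # list_arguments(net) == [:x, :fc_weight, :fc_bias]
    exec = mx.bind(net, mx.cpu(),
                   Dict(:x         => mx.ones(4, 8),   # 4 features, batch of 8
                        :fc_weight => mx.zeros(4, 2),
                        :fc_bias   => mx.zeros(2));
                   grad_req = mx.GRAD_NOP)             # inference only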
@@ -97,7 +100,7 @@ function bind(self :: SymbolicNode, ctx :: Context, args :: Union{Vector{NDArray
   elseif isa(grad_req, Vector{GRAD_REQ})
     @assert(length(grad_req) == length(args))
     reqs = MX_uint[grad_req...]
-  elseif isa(grad_req, Dict{Base.Symbol, GRAD_REQ})
+  elseif isa(grad_req, Dict{Symbol, GRAD_REQ})
     reqs = MX_uint[get(grad_req, name, GRAD_NOP) for name in arg_names]
   end

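The dispatch above accepts three shapes of `grad_req`. Hedged examples continuing the earlier sketch (`args` stands for the vector or dict of input arrays; `GRAD_ADD` is assumed to sit in the `GRAD_REQ` enum alongside `GRAD_WRITE` and `GRAD_NOP`):

    mx.bind(net, mx.cpu(), args; grad_req = mx.GRAD_WRITE)           # one value for all
    mx.bind(net, mx.cpu(), args; grad_req = fill(mx.GRAD_WRITE, 3))  # one per argument
    mx.bind(net, mx.cpu(), args; grad_req = Dict(:fc_weight => mx.GRAD_ADD))
    # names missing from the Dict fall back to GRAD_NOP via the `get` above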
@@ -111,20 +114,16 @@
   executor = Executor(MX_ExecutorHandle(ref_hdr[]), self,
                       args, args_grad, aux_states)
 end
-function bind(self :: SymbolicNode; kwargs...)
+
+function bind(x::SymbolicNode; context::Context = cpu(), kwargs...)
   kwargs = Dict(kwargs)
   @assert(haskey(kwargs, :args), "Must specify args")
   args = pop!(kwargs, :args)
-  if haskey(kwargs, :context)
-    context = pop!(kwargs, :context)
-  else
-    context = cpu()
-  end
-  bind(self, context, args; kwargs...)
+  bind(x, context, args; kwargs...)
 end

-function simple_bind(self :: SymbolicNode, ctx :: Context;
-                     grad_req :: Union{GRAD_REQ, Dict{Symbol, GRAD_REQ}}=GRAD_WRITE,
+function simple_bind(self::SymbolicNode, ctx::Context;
+                     grad_req::Union{GRAD_REQ,Dict{Symbol,GRAD_REQ}} = GRAD_WRITE,
                      kwargs...)
   arg_shapes, out_shapes, aux_shapes = infer_shape(self; kwargs...)
   @assert(!isa(arg_shapes, Void), "Information not enough to perform complete shape inference")
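`simple_bind` needs only the input shapes: it calls `infer_shape` (first line above) and allocates the argument and gradient `NDArray`s itself. A hedged sketch reusing the hypothetical `net`:

    # Weight and bias shapes are inferred from the shape given for :x.
    exec = mx.simple_bind(net, mx.cpu(); x = (4, 8), grad_req = mx.GRAD_WRITE)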
@@ -168,21 +167,15 @@ function forward(self::Executor; is_train::Bool = false, kwargs...)
   self.outputs
 end

-function backward(self :: Executor)
-  backward(self, NDArray[])
-end
-function backward(self :: Executor, out_grad :: NDArray)
-  backward(self, [out_grad])
-end
-function backward(self :: Executor, out_grads :: Vector{NDArray})
-  out_grads = MX_handle[out_grads...]
-  @mxcall(:MXExecutorBackward, (MX_handle, MX_uint, Ptr{MX_handle}), self, length(out_grads), out_grads)
-end
+backward(x::Executor) = backward(x, NDArray[])
+backward(x::Executor, out_grad::NDArray) = backward(x, [out_grad])
+backward(x::Executor, out_grads::VecOfNDArray) =
+  @mxcall(:MXExecutorBackward, (MX_handle, MX_uint, Ptr{MX_handle}),
+          x, length(out_grads), MX_handle[out_grads...])
-

-function copy_params_from(self::Executor, arg_params::Dict{Base.Symbol,NDArray},
-                          aux_params::Union{Void,Dict{Base.Symbol,NDArray}}=nothing;
-                          allow_extra_params::Bool=false)
+function copy_params_from(self::Executor, arg_params::Dict{Symbol},
+                          aux_params::Dict{Symbol} = Dict{Symbol,Any}();
+                          allow_extra_params::Bool = false)
   for (name, array) in arg_params
     if haskey(self.arg_dict, name)
       copy!(self.arg_dict[name], array)
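With `backward` collapsed into one-line methods, a training step over the hypothetical `exec` reads as follows (a sketch, assuming `mx.ones` accepts a shape tuple):

    mx.forward(exec, is_train = true)       # populates exec.outputs
    grad = mx.ones(size(exec.outputs[1]))   # seed output gradient of ones
    mx.backward(exec, grad)                 # writes into the bound grad arrays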
@@ -191,13 +184,11 @@ function copy_params_from(self::Executor, arg_params::Dict{Base.Symbol,NDArray},
     end
   end

-  if !isa(aux_params, Void)
-    for (name, array) in aux_params
-      if haskey(self.aux_dict, name)
-        copy!(self.aux_dict[name], array)
-      else
-        @assert(allow_extra_params, "Extra auxiliary state $name not recognized")
-      end
+  for (name, array) in aux_params
+    if haskey(self.aux_dict, name)
+      copy!(self.aux_dict[name], array)
+    else
+      @assert(allow_extra_params, "Extra auxiliary state $name not recognized")
     end
   end
 end
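Since `aux_params` now defaults to an empty `Dict`, the `Void` branch disappears and callers may omit it entirely. A hedged call sketch (`w` and `b` are hypothetical trained `NDArray`s):

    mx.copy_params_from(exec, Dict(:fc_weight => w, :fc_bias => b))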