From b3bf1fc985bcb449d3b1f9ee8c9270c7b05c3661 Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Sun, 19 Nov 2017 14:37:16 +0800
Subject: [PATCH 01/37] wip

---
 src/executor.jl          |  48 +++++++-----
 src/io.jl                |  45 ++++++-----
 src/kvstore.jl           |  21 ++---
 src/model.jl             |   4 +-
 src/ndarray.jl           | 161 ++++++++++++++++++---------------------
 test/unittest/ndarray.jl |   6 +-
 6 files changed, 141 insertions(+), 144 deletions(-)

diff --git a/src/executor.jl b/src/executor.jl
index 26da87c69..adc079c9c 100644
--- a/src/executor.jl
+++ b/src/executor.jl
@@ -5,19 +5,25 @@ An executor is a realization of a symbolic architecture defined by a `SymbolicNo
 The actual forward and backward computation specified by the network architecture can
 be carried out with an executor.
 """
-mutable struct Executor
+const VecOfNDArray = AbstractVector{<:NDArray}
+mutable struct Executor{A<:VecOfNDArray,
+                        B<:VecOfNDArray,
+                        G<:AbstractVector{<:Union{Void,NDArray}},
+                        O<:VecOfNDArray,
+                        D<:Dict{Symbol,<:NDArray},
+                        E<:Dict{Symbol,<:NDArray}}
   handle :: MX_ExecutorHandle
   symbol :: SymbolicNode
-  arg_arrays  :: Vector{NDArray}
-  grad_arrays :: Vector{Union{Void,NDArray}}
-  aux_arrays  :: Vector{NDArray}
-  outputs     :: Vector{NDArray}
-  arg_dict    :: Dict{Base.Symbol, NDArray}
-  aux_dict    :: Dict{Base.Symbol, NDArray}
+  arg_arrays  :: A
+  grad_arrays :: G
+  aux_arrays  :: B
+  outputs     :: O
+  arg_dict    :: D
+  aux_dict    :: E
 end
 function Executor(hdr :: MX_ExecutorHandle, symbol :: SymbolicNode,
-                  arg_arrays :: Vector{NDArray}, grad_arrays :: Vector{Union{Void,NDArray}},
-                  aux_arrays :: Vector{NDArray})
+                  arg_arrays :: Vector{<:NDArray}, grad_arrays :: Vector{<:Union{Void,NDArray}},
+                  aux_arrays :: Vector{<:NDArray})
   # get output arrays
   ref_size = Ref{MX_uint}(0)
   ref_hdrs = Ref{Ptr{MX_handle}}(0)
@@ -28,11 +34,11 @@ function Executor(hdr :: MX_ExecutorHandle, symbol :: SymbolicNode,
   arg_names = list_arguments(symbol)
   @assert(length(arg_names) == length(unique(arg_names)), "Duplicated names in arguments: $arg_names")
-  arg_dict = Dict{Base.Symbol,NDArray}(zip(arg_names, arg_arrays))
+  arg_dict = Dict(zip(arg_names, arg_arrays))
 
   aux_names = list_auxiliary_states(symbol)
   @assert(length(aux_names) == length(unique(aux_names)), "Duplicated names in auxiliary states: $aux_names")
-  aux_dict = Dict{Base.Symbol,NDArray}(zip(aux_names, aux_arrays))
+  aux_dict = Dict(zip(aux_names, aux_arrays))
 
   Executor(hdr, symbol, arg_arrays, grad_arrays, aux_arrays, out_arrays, arg_dict, aux_dict)
 end
@@ -43,11 +49,13 @@ end
 Base.convert(t::Type{MX_handle}, obj::Executor) = Base.unsafe_convert(t, obj)
 Base.cconvert(t::Type{MX_handle}, obj::Executor) = Base.unsafe_convert(t, obj)
 
-function _get_ndarray_inputs(arg_key::AbstractString, args::Vector{NDArray}, arg_names::Vector{Base.Symbol}, allow_missing::Bool)
+function _get_ndarray_inputs(arg_key::AbstractString, args::Vector{<:NDArray},
+                             arg_names::Vector{Symbol}, allow_missing::Bool)
   @assert(length(args) == length(arg_names), "Length of $arg_key does not match number of arguments")
   return (MX_handle[args...], args)
 end
-function _get_ndarray_inputs(arg_key::AbstractString, args::Dict{Base.Symbol,NDArray}, arg_names::Vector{Base.Symbol}, allow_missing::Bool)
+function _get_ndarray_inputs(arg_key::AbstractString, args::Dict{Symbol,<:NDArray},
+                             arg_names::Vector{Symbol}, allow_missing::Bool)
   args_vec = map(arg_names) do name
     arr = get(args, name, nothing)
     if !allow_missing
@@ -81,9 +89,9 @@ Create an `Executor` by binding a `SymbolicNode` to concrete `NDArray`.
* `aux_states`: * `grad_req`: """ -function bind(self :: SymbolicNode, ctx :: Context, args :: Union{Vector{NDArray},Dict{Base.Symbol,NDArray}}; - args_grad :: Union{Vector{NDArray},Dict{Base.Symbol,NDArray}} = Dict{Base.Symbol,NDArray}(), - aux_states :: Union{Vector{NDArray},Dict{Base.Symbol,NDArray}} = Dict{Base.Symbol,NDArray}(), +function bind(self :: SymbolicNode, ctx :: Context, args :: Union{Vector{<:NDArray},Dict{Symbol,<:NDArray}}; + args_grad :: Union{Vector{<:NDArray},Dict{Symbol,<:NDArray}} = Dict{Symbol,NDArray}(), + aux_states :: Union{Vector{<:NDArray},Dict{Symbol,<:NDArray}} = Dict{Symbol,NDArray}(), grad_req :: Union{GRAD_REQ,Vector{GRAD_REQ},Dict{Base.Symbol,GRAD_REQ}} = GRAD_WRITE) arg_names = list_arguments(self) @@ -97,7 +105,7 @@ function bind(self :: SymbolicNode, ctx :: Context, args :: Union{Vector{NDArray elseif isa(grad_req, Vector{GRAD_REQ}) @assert(length(grad_req) == length(args)) reqs = MX_uint[grad_req...] - elseif isa(grad_req, Dict{Base.Symbol, GRAD_REQ}) + elseif isa(grad_req, Dict{Symbol, GRAD_REQ}) reqs = MX_uint[get(grad_req, name, GRAD_NOP) for name in arg_names] end @@ -174,14 +182,14 @@ end function backward(self :: Executor, out_grad :: NDArray) backward(self, [out_grad]) end -function backward(self :: Executor, out_grads :: Vector{NDArray}) +function backward(self :: Executor, out_grads :: Vector{<:NDArray}) out_grads = MX_handle[out_grads...] @mxcall(:MXExecutorBackward, (MX_handle, MX_uint, Ptr{MX_handle}), self, length(out_grads), out_grads) end -function copy_params_from(self::Executor, arg_params::Dict{Base.Symbol,NDArray}, - aux_params::Union{Void,Dict{Base.Symbol,NDArray}}=nothing; +function copy_params_from(self::Executor, arg_params::Dict{Symbol,<:NDArray}, + aux_params::Union{Void,Dict{Symbol,<:NDArray}}=nothing; allow_extra_params::Bool=false) for (name, array) in arg_params if haskey(self.arg_dict, name) diff --git a/src/io.jl b/src/io.jl index 2ba0bf78a..53101590d 100644 --- a/src/io.jl +++ b/src/io.jl @@ -113,9 +113,9 @@ function get_label end A basic subclass of `AbstractDataBatch`, that implement the interface by accessing member fields. """ -mutable struct DataBatch <: AbstractDataBatch - data :: Vector{NDArray} - label :: Vector{NDArray} +mutable struct DataBatch{V<:AbstractVector{<:NDArray}} <: AbstractDataBatch + data :: V + label :: V count :: Int end count_samples(batch :: DataBatch) = batch.count @@ -127,10 +127,10 @@ get_label(::Provider, batch :: DataBatch) where {Provider<:AbstractDataProvider} A alias type of `Tuple{UnitRange{Int},NDArray}`. """ -const SlicedNDArray = Tuple{UnitRange{Int},NDArray} +const SlicedNDArray = Tuple{UnitRange{Int},<:NDArray} function _load_general!(provider :: AbstractDataProvider, batch :: AbstractDataBatch, - targets :: Vector{Vector{SlicedNDArray}}, loader::Function) + targets :: Vector{<:Vector{<:SlicedNDArray}}, loader::Function) data = loader(provider, batch) for (d_src, d_targets) in zip(data, targets) for (slice_idx, d_dst) in d_targets @@ -157,7 +157,7 @@ This utility function is used in data parallelization, where a mini-batch is spl and computed on several different devices. """ function load_data!(provider :: AbstractDataProvider, batch :: AbstractDataBatch, - targets :: Vector{Vector{SlicedNDArray}}) + targets :: Vector{<:Vector{<:SlicedNDArray}}) _load_general!(provider, batch, targets, get_data) end @@ -171,16 +171,18 @@ end The same as [`load_data!`](@ref), except that this is for loading labels. 
""" function load_label!(provider :: AbstractDataProvider, batch :: AbstractDataBatch, - targets :: Vector{Vector{SlicedNDArray}}) + targets :: Vector{<:Vector{<:SlicedNDArray}}) _load_general!(provider, batch, targets, get_label) end -function load_data!(provider :: AbstractDataProvider, batch :: AbstractDataBatch, targets :: Vector{NDArray}) +function load_data!(provider :: AbstractDataProvider, batch :: AbstractDataBatch, + targets :: Vector{<:NDArray}) for (src, dst) in zip(get_data(provider, batch), targets) copy!(dst, src) end end -function load_label!(provider :: AbstractDataProvider, batch :: AbstractDataBatch, targets :: Vector{NDArray}) +function load_label!(provider :: AbstractDataProvider, batch :: AbstractDataBatch, + targets :: Vector{<:NDArray}) for (src, dst) in zip(get_label(provider, batch), targets) copy!(dst, src) end @@ -216,7 +218,7 @@ end eachbatch(provider::AbstractDataProvider) Allows you to perform operations on data every epoch. This is especially useful -when you need to perform real-time augmentation of the data. +when you need to perform real-time augmentation of the data. # Arguments: * `provider`: an instance of the custom DataProvider type. You must return this @@ -229,7 +231,7 @@ eachbatch(provider :: AbstractDataProvider) = provider ArrayDataProvider A convenient tool to iterate `NDArray` or Julia `Array`. - + ArrayDataProvider(data[, label]; batch_size, shuffle, data_padding, label_padding) Construct a data provider from `NDArray` or Julia Arrays. @@ -253,18 +255,18 @@ TODO: remove `data_padding` and `label_padding`, and implement rollover that cop the last or first several training samples to feed the padding. """ mutable struct ArrayDataProvider <: AbstractDataProvider - data_arrays :: Vector{Array} - data_names :: Vector{Base.Symbol} - label_arrays :: Vector{Array} - label_names :: Vector{Base.Symbol} + data_arrays + data_names :: Vector{Symbol} + label_arrays + label_names :: Vector{Symbol} batch_size :: Int sample_count :: Int shuffle :: Bool data_padding :: MX_float label_padding :: MX_float - data_batch :: Vector{NDArray} - label_batch :: Vector{NDArray} + data_batch + label_batch end # Julia's type system is sometimes very frustrating. You cannot specify a function @@ -349,7 +351,8 @@ function ArrayDataProvider(data::Any, label::Any; batch_size::Int=0, shuffle::Bo end ArrayDataProvider(data_arrays, data_names, label_arrays, label_names, batch_size, - sample_count, shuffle, data_padding, label_padding, data_batch, label_batch) + sample_count, shuffle, MX_float(data_padding), MX_float(label_padding), + data_batch, label_batch) end function provide_data(provider::ArrayDataProvider) @@ -425,8 +428,8 @@ a list of built-in data iterators. 
""" mutable struct MXDataProvider <: AbstractDataProvider handle :: MX_DataIterHandle - data_shape :: Vector{Tuple{Base.Symbol, Tuple}} - label_shape:: Vector{Tuple{Base.Symbol, Tuple}} + data_shape :: Vector{Tuple{Symbol, Tuple}} + label_shape:: Vector{Tuple{Symbol, Tuple}} batch_size :: Int # those two a auxiliary variables to help avoid calling reset @@ -569,7 +572,7 @@ function _define_data_iter_creator(hdr :: MX_handle) isprovider = endswith(string(iter_name), "Iter") signature = _format_signature(Int(ref_narg[]), ref_arg_names) f_desc = " " * string(iter_name) * "(" *signature * ")\n\n" - if isprovider + if isprovider f_desc *= "Can also be called with the alias `$(string(iter_name)[1:end-4] * "Provider")`.\n" end f_desc *= unsafe_string(ref_desc[]) * "\n\n" diff --git a/src/kvstore.jl b/src/kvstore.jl index 1ac56260b..2e424e6fb 100644 --- a/src/kvstore.jl +++ b/src/kvstore.jl @@ -20,7 +20,7 @@ end Base.convert(t::Type{MX_handle}, obj::KVStore) = Base.unsafe_convert(t, obj) Base.cconvert(t::Type{MX_handle}, obj::KVStore) = Base.unsafe_convert(t, obj) -function _flatten_kvlist(keys :: Vector{Int}, vals :: Vector{Vector{NDArray}}) +function _flatten_kvlist(keys :: Vector{Int}, vals :: Vector{<:Vector{<:NDArray}}) @assert length(keys) == length(vals) keys_flt = Int[] vals_flt = NDArray[] @@ -34,13 +34,13 @@ end function init!(self :: KVStore, key :: Int, val :: NDArray) init!(self, [key], [val]) end -function init!(self :: KVStore, key :: Int, vals :: Vector{NDArray}) +function init!(self :: KVStore, key :: Int, vals :: Vector{<:NDArray}) init!(self, Base.ones(Int, length(vals))*key, vals) end -function init!(self :: KVStore, keys :: Vector{Int}, vals :: Vector{Vector{NDArray}}) +function init!(self :: KVStore, keys :: Vector{Int}, vals :: Vector{Vector{<:NDArray}}) init!(self, _flatten_kvlist(keys, vals)...) end -function init!(self :: KVStore, keys :: Vector{Int}, vals :: Vector{NDArray}) +function init!(self :: KVStore, keys :: Vector{Int}, vals :: Vector{<:NDArray}) @assert length(keys) == length(vals) keys = Cint[keys...] vals = MX_handle[vals...] @@ -52,13 +52,14 @@ import Base.push! function push!(self :: KVStore, key :: Int, val :: NDArray; priority :: Int = 0) push!(self, [key], [val]; priority = priority) end -function push!(self :: KVStore, key :: Int, vals :: Vector{NDArray}; priority :: Int = 0) +function push!(self :: KVStore, key :: Int, vals :: Vector{<:NDArray}; priority :: Int = 0) push!(self, Base.ones(Int, length(vals))*key, vals; priority = priority) end -function push!(self :: KVStore, keys :: Vector{Int}, vals :: Vector{Vector{NDArray}}; priority::Int=0) +function push!(self :: KVStore, keys :: Vector{Int}, vals :: Vector{<:Vector{<:NDArray}}; + priority::Int=0) push!(self, _flatten_kvlist(keys, vals)...; priority = priority) end -function push!(self :: KVStore, keys :: Vector{Int}, vals :: Vector{NDArray}; priority::Int=0) +function push!(self :: KVStore, keys :: Vector{Int}, vals :: Vector{<:NDArray}; priority::Int=0) @assert length(keys) == length(vals) keys = Cint[keys...] vals = MX_handle[vals...] 
@@ -69,13 +70,13 @@ end function pull!(self :: KVStore, key :: Int, out :: NDArray; priority :: Int = 0) pull!(self, [key], [out]) end -function pull!(self :: KVStore, key :: Int, outs :: Vector{NDArray}; priority :: Int = 0) +function pull!(self :: KVStore, key :: Int, outs :: Vector{<:NDArray}; priority :: Int = 0) pull!(self, Base.ones(Int, length(outs))*key, outs; priority = priority) end -function pull!(self :: KVStore, keys :: Vector{Int}, outs :: Vector{Vector{NDArray}}; priority::Int=0) +function pull!(self :: KVStore, keys :: Vector{Int}, outs :: Vector{<:Vector{<:NDArray}}; priority::Int=0) pull!(self, _flatten_kvlist(keys, outs)...; priority = priority) end -function pull!(self :: KVStore, keys :: Vector{Int}, outs :: Vector{NDArray}; priority::Int=0) +function pull!(self :: KVStore, keys :: Vector{Int}, outs :: Vector{<:NDArray}; priority::Int=0) @assert length(keys) == length(outs) keys = Cint[keys...] outs = MX_handle[outs...] diff --git a/src/model.jl b/src/model.jl index df15e4cac..8845f65ad 100644 --- a/src/model.jl +++ b/src/model.jl @@ -18,8 +18,8 @@ mutable struct FeedForward <: AbstractModel arch :: SymbolicNode ctx :: Vector{Context} - arg_params :: Dict{Base.Symbol, NDArray} - aux_params :: Dict{Base.Symbol, NDArray} + arg_params :: Dict{Symbol,<:NDArray} + aux_params :: Dict{Symbol,<:NDArray} pred_exec :: Union{Executor, Void} diff --git a/src/ndarray.jl b/src/ndarray.jl index 9aca2f3d9..2dfa3a8cb 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -3,7 +3,7 @@ const DType = Union{Float32, Float64, Float16, UInt8, Int32, Int8, Int64} @enum TypeFlag kFloat32 kFloat64 kFloat16 kUint8 kInt32 kInt8 kInt64 const DEFAULT_DTYPE = Float32 # MSHADOW_DEFAULT_DTYPE -function toTypeFlag(:: Type{T}) where T <: DType +function toTypeFlag(T::Type{<:DType}) if T == Float32 return kFloat32 elseif T == Float64 @@ -23,7 +23,7 @@ function toTypeFlag(:: Type{T}) where T <: DType end end -function fromTypeFlag(T :: TypeFlag) +function fromTypeFlag(T::TypeFlag) if T == kFloat32 return Float32 elseif T == kFloat64 @@ -89,32 +89,32 @@ of tensor-based computation. C/C++/Python shape (100,1,28,28), while in Julia, the same piece of memory have shape (28,28,1,100). 
""" -mutable struct NDArray +mutable struct NDArray{T,D} handle :: MX_NDArrayHandle writable :: Bool - function NDArray(handle, writable=true) - new(handle, writable) - end + NDArray{T,D}(handle, writable = true) where {T,D} = new(handle, writable) end NDArray(x::AbstractArray{T}) where {T<:DType} = copy(collect(x), cpu()) NDArray(x::Array{T}) where {T<:DType} = copy(x, cpu()) +NDArray(handle, writable = true) = + NDArray{eltype(handle), ndims(handle)}(handle, writable) const NDArrayOrReal = Union{NDArray, Real} @unfuse NDArray -function Base.show(io :: IO, arr :: NDArray) - println(io, "$(join(size(arr), "×")) mx.NDArray{$(eltype(arr))} @ $(context(arr)):") - Base.showarray(io, try_get_shared(arr, sync=:read), false, header=false) +function Base.show(io::IO, x::NDArray{T,N}) where {T,N} + type_ = split(string(typeof(x)), '.', limit=2)[end] + println(io, "$(join(size(x), "×")) $(type_) @ $(context(x)):") + Base.showarray(io, try_get_shared(x, sync=:read), false, header=false) end -function Base.unsafe_convert(::Type{MX_handle}, obj::NDArray) +Base.unsafe_convert(::Type{MX_handle}, obj::NDArray) = Base.unsafe_convert(MX_handle, obj.handle) -end -Base.convert(t::Type{MX_handle}, obj::NDArray) = Base.unsafe_convert(t, obj) -Base.cconvert(t::Type{MX_handle}, obj::NDArray) = Base.unsafe_convert(t, obj) +Base.convert(T::Type{MX_handle}, obj::NDArray) = Base.unsafe_convert(T, obj) +Base.cconvert(T::Type{MX_handle}, obj::NDArray) = Base.unsafe_convert(T, obj) ################################################################################ # NDArray functions exported to the users @@ -134,21 +134,15 @@ end """ - empty(DType, shape :: Tuple, ctx :: Context) - empty(DType, shape :: Tuple) + empty(DType, dims[, ctx::Context = cpu()]) + empty(DType, dims) empty(DType, dim1, dim2, ...) Allocate memory for an uninitialized `NDArray` with a specified type. """ -function empty(::Type{T}, shape :: NTuple{N, Int}) where {N,T<:DType} - empty(T, shape, cpu()) -end -function empty(:: Type{T}, shape :: NTuple{N, Int}, ctx :: Context) where {N,T<:DType} - NDArray(_ndarray_alloc(T, shape, ctx, false)) -end -function empty(:: Type{T}, shape :: Int...) where T<:DType - empty(T, shape) -end +empty(::Type{T}, dims::NTuple{N, Int}, ctx::Context = cpu()) where {N,T<:DType} = + NDArray{T, N}(_ndarray_alloc(T, dims, ctx, false)) +empty(::Type{T}, dims::Int...) where {T<:DType} = empty(T, dims) """ empty(shape :: Tuple, ctx :: Context) @@ -167,54 +161,39 @@ function empty(shape :: Int...) empty(shape) end -import Base.similar - """ - similar(arr :: NDArray) + similar(x::NDArray) -Create an `NDArray` with similar shape, data type, and context with the given one. +Create an `NDArray` with similar shape, data type, +and context with the given one. +Note that the returned `NDArray` is uninitialized. """ -function similar(arr :: NDArray) - empty(eltype(arr), size(arr), context(arr)) -end +Base.similar(x::NDArray) = empty(eltype(x), size(x), context(x)) """ - zeros(DType, shape :: Tuple, ctx :: Context) - zeros(DType, shape :: Tuple) - zeros(DType, dim1, dim2, ...) + zeros(DType, dims[, ctx::Context = cpu()]) + zeros(DType, dims...) -Create zero-ed `NDArray` with specific shape and type +Create zero-ed `NDArray` with specific shape and type. 
""" -function zeros(:: Type{T}, shape :: NTuple{N, Int}) where {N,T<:DType} - zeros(T, shape, cpu()) -end -function zeros(:: Type{T}, shape :: NTuple{N, Int}, ctx :: Context) where {N,T<:DType} - arr = empty(T, shape, ctx) +function zeros(::Type{T}, dims::NTuple{N, Int}, ctx::Context = cpu()) where {N,T<:DType} + arr = empty(T, dims, ctx) arr[:] = zero(T) - return arr -end -function zeros(:: Type{T}, shape :: Int...) where T<:DType - zeros(T, shape) + arr end +zeros(::Type{T}, dims::Int...) where {T<:DType} = zeros(T, dims) + """ - zeros(shape :: Tuple, ctx :: Context) - zeros(shape :: Tuple) - zeros(dim1, dim2, ...) + zeros(dims[, ctx::Context = cpu()]) + zeros(dims...) Create zero-ed `NDArray` with specific shape. """ -function zeros(shape :: NTuple{N, Int}) where N - zeros(shape, cpu()) -end -function zeros(shape :: NTuple{N, Int}, ctx :: Context) where N - arr = empty(shape, ctx) - arr[:] = 0 - return arr -end -function zeros(shape :: Int...) - zeros(shape) -end +zeros(dims::NTuple{N, Int}, ctx::Context = cpu()) where N = + zeros(MX_float, dims, ctx) + +zeros(dims::Int...) = zeros(dims) """ ones(DType, shape :: Tuple, ctx :: Context) @@ -257,11 +236,11 @@ end import Base: size, length, ndims, eltype """ - size(arr :: NDArray) - size(arr :: NDArray, dim :: Int) + size(x::NDArray) + size(x::NDArray, dim) -Get the shape of an `NDArray`. The shape is in Julia's column-major convention. See -also the notes on NDArray shapes [`NDArray`](@ref). +Get the shape of an `NDArray`. The shape is in Julia's column-major convention. +See also the notes on NDArray shapes [`NDArray`](@ref). """ function size(arr :: NDArray) ref_ndim = Ref{MX_uint}(0) @@ -275,45 +254,50 @@ function size(arr :: NDArray, dim :: Int) end """ - length(arr :: NDArray) + length(x::NDArray) Get the number of elements in an `NDArray`. """ -function length(arr :: NDArray) - prod(size(arr)) -end +length(x::NDArray) = prod(size(x)) """ - ndims(arr :: NDArray) + ndims(x::NDArray) -Get the number of dimensions of an `NDArray`. Is equivalent to `length(size(arr))`. +Get the number of dimensions of an `NDArray`. +Is equivalent to `length(size(arr))`. """ -function ndims(arr :: NDArray) - length(size(arr)) +ndims(x::NDArray) = ndims(x.handle) + +function ndims(x::MX_NDArrayHandle)::Int + ref_ndim = Ref{MX_uint}(0) + ref_shape = Ref{Ptr{MX_uint}}(0) + @mxcall(:MXNDArrayGetShape, (MX_handle, Ref{MX_uint}, Ref{Ptr{MX_uint}}), + x, ref_ndim, ref_shape) + ref_ndim[] end """ - eltype(arr :: NDArray) + eltype(x::NDArray) Get the element type of an `NDArray`. """ -function eltype(arr :: T) where T <: Union{NDArray, MX_NDArrayHandle} +function eltype(x::Union{NDArray, MX_NDArrayHandle}) dtype_ref = Ref{Cint}(0) - @mxcall(:MXNDArrayGetDType, (MX_handle, Ptr{Cint}), arr, dtype_ref) + @mxcall(:MXNDArrayGetDType, (MX_handle, Ptr{Cint}), x, dtype_ref) - if dtype_ref[] == -1 # arr->is_none() - warn("Eltype of $arr is not defined") + if dtype_ref[] == -1 # x->is_none() + warn("Eltype of $x is not defined") Base.show_backtrace(STDOUT, backtrace()) println() - return Float32 + Float32 else - return fromTypeFlag(TypeFlag(dtype_ref[])) + fromTypeFlag(TypeFlag(dtype_ref[])) end end -@inline _first(arr::NDArray) = try_get_shared(arr, sync = :read) |> first +@inline _first(x::NDArray) = try_get_shared(x, sync = :read) |> first -Base.first(arr::NDArray) = _first(arr) +Base.first(x::NDArray) = _first(x) """ slice(arr :: NDArray, start:stop) @@ -974,7 +958,7 @@ corresponding components enabled. 
Examples: * `hdfs://my-bucket/path/my-hdfs-ndarray` * `/path-to/my-local-ndarray` """ -function load(filename::AbstractString, ::Type{NDArray}) +function load(filename::AbstractString, ::Type{<:NDArray}) out_size = Ref{MX_uint}(0) out_hdrs = Ref{Ptr{MX_handle}}(0) out_name_size = Ref{MX_uint}(0) @@ -993,22 +977,23 @@ function load(filename::AbstractString, ::Type{NDArray}) end """ - save(filename :: AbstractString, data) + save(filename::AbstractString, data) Save NDarrays to binary file. Filename could be S3 or HDFS address, if `libmxnet` is built with corresponding support (see `load`). * `filename::String`: path to the binary file to write to. -* `data`: data to save to file. Data can be a`NDArray`, a `Vector{NDArray}`, or a `Dict{Base.Symbol, NDArray}`. +* `data`: data to save to file. Data can be a`NDArray`, a `Vector{<:NDArray}`, + or a `Dict{Symbol, <:NDArray}`. """ -function save(filename::String, data::NDArray) - save(filename, [data]) -end -function save(filename::String, data::Vector{NDArray}) +save(filename::String, data::NDArray) = save(filename, [data]) + +function save(filename::String, data::Vector{<:NDArray}) @mxcall(:MXNDArraySave, (char_p, MX_uint, Ptr{MX_handle}, char_pp), filename, length(data), MX_handle[data...], char_pp(0)) end -function save(filename::String, data::Dict{Base.Symbol,NDArray}) + +function save(filename::String, data::Dict{Symbol,<:NDArray}) names = [k for k in keys(data)] arrays = MX_handle[data[k] for k in names] names = String[string(k) for k in names] @@ -1167,7 +1152,7 @@ function _get_ndarray_function_def(name :: String) func_name = Symbol(name) func_def = quote - function $func_name(::Type{NDArray}, args::NDArray...; out=nothing, kwargs...) + function $func_name(::Type{<:NDArray}, args::NDArray...; out=nothing, kwargs...) 
       if out != nothing
         output_vars = out
         if isa(output_vars, NDArray)
diff --git a/test/unittest/ndarray.jl b/test/unittest/ndarray.jl
index 5217ca80f..06082e8b7 100644
--- a/test/unittest/ndarray.jl
+++ b/test/unittest/ndarray.jl
@@ -417,7 +417,7 @@ function test_saveload()
   j_array, nd_array = rand_tensors(dims)
   mx.save(fname, nd_array)
   data = mx.load(fname, mx.NDArray)
-  @test data isa Vector{mx.NDArray}
+  @test data isa Vector{<:mx.NDArray}
   @test length(data) == 1
   @test copy(data[1]) ≈ j_array
@@ -426,7 +426,7 @@
   nd_arrays = mx.NDArray[x[2] for x in arrays]
   mx.save(fname, nd_arrays)
   data = mx.load(fname, mx.NDArray)
-  @test isa(data, Vector{mx.NDArray})
+  @test data isa Vector{<:mx.NDArray}
   @test length(data) == n_arrays
   for i = 1:n_arrays
     @test copy(data[i]) ≈ arrays[i][1]
@@ -437,7 +437,7 @@
   dict = Dict([(n, v) for (n,v) in zip(names, nd_arrays)])
   mx.save(fname, dict)
   data = mx.load(fname, mx.NDArray)
-  @test data isa Dict{Symbol, mx.NDArray}
+  @test data isa Dict{Symbol,<:mx.NDArray}
   @test length(data) == n_arrays
   for i = 1:n_arrays
     @test copy(data[names[i]]) ≈ arrays[i][1]

From e0f73bc7e207db8dd6e3fdf618c35ec6738189e8 Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Sun, 19 Nov 2017 14:38:01 +0800
Subject: [PATCH 02/37] ndarray: add outer constructor for AbstractArray

---
 src/ndarray.jl | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/ndarray.jl b/src/ndarray.jl
index 2dfa3a8cb..e43fb78a0 100644
--- a/src/ndarray.jl
+++ b/src/ndarray.jl
@@ -101,6 +101,9 @@ NDArray(x::Array{T}) where {T<:DType} = copy(x, cpu())
 NDArray(handle, writable = true) =
   NDArray{eltype(handle), ndims(handle)}(handle, writable)
 
+NDArray(x::AbstractArray{T}) where {T<:DType} = copy(collect(x), cpu())
+NDArray(x::Array{T}) where {T<:DType} = copy(x, cpu())
+
 const NDArrayOrReal = Union{NDArray, Real}
 
 @unfuse NDArray

From e43b6224d0c19ca35f19cc2e53fbd49d1f774394 Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Sun, 19 Nov 2017 15:45:11 +0800
Subject: [PATCH 03/37] ndarray: refine copy

---
 src/ndarray.jl | 17 +++++------------
 1 file changed, 5 insertions(+), 12 deletions(-)

diff --git a/src/ndarray.jl b/src/ndarray.jl
index e43fb78a0..90b58f33d 100644
--- a/src/ndarray.jl
+++ b/src/ndarray.jl
@@ -500,22 +500,15 @@ Create a copy of an array. When no `Context` is given, create a Julia `Array`.
 Otherwise, create an `NDArray` on the specified context.
 """
 # Create copy: NDArray -> Julia Array
-function copy(arr :: NDArray)
-  j_arr = Array{eltype(arr)}(size(arr))
-  copy!(j_arr, arr)
-end
+copy(x::NDArray{T,D}) where{T,D} = copy!(Array{T,D}(size(x)), x)
 
 # Create copy: NDArray -> NDArray in a given context
-function copy(arr :: NDArray, ctx :: Context)
-  dst = NDArray(_ndarray_alloc(eltype(arr), size(arr), ctx, true))
-  copy!(dst, arr)
-end
+copy(x::NDArray{T,D}, ctx::Context) where {T,D} =
+  copy!(NDArray{T,D}(_ndarray_alloc(T, size(x), ctx, true)), x)
 
 # Create copy: Julia Array -> NDArray in a given context
-function copy(arr :: Array{T}, ctx :: Context) where T<:DType
-  dst = empty(T, size(arr), ctx)
-  copy!(dst, arr)
-end
+copy(x::Array{T}, ctx::Context) where {T<:DType} =
+  copy!(empty(T, size(x), ctx), x)
 
 """
     convert(::Type{Array{T}}, arr :: NDArray)

From 2640cccedd368ebb496791a467778395552777d1 Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Sun, 19 Nov 2017 16:11:07 +0800
Subject: [PATCH 04/37] ndarray: refine copy!
--- src/ndarray.jl | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/ndarray.jl b/src/ndarray.jl index 90b58f33d..2abdb15da 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -450,7 +450,7 @@ import Base: copy!, copy, convert, deepcopy Copy contents of `src` into `dst`. """ -function copy!(dst :: NDArray, src :: NDArray) +function copy!(dst::NDArray, src::NDArray) @assert(dst.writable) if dst.handle == src.handle warn("Copying an NDArray to itself") @@ -461,33 +461,31 @@ function copy!(dst :: NDArray, src :: NDArray) return dst end -function copy!(dst :: Array{T}, src :: NDArray) where T<:DType - @assert T == eltype(src) +function copy!(dst::Array{T}, src::NDArray{T}) where T<:DType @assert size(dst) == size(src) @mxcall(:MXNDArraySyncCopyToCPU, (MX_handle, Ptr{Void}, Csize_t), src, pointer(dst), length(dst)) - return dst -end -function copy!(dst :: Array{T}, src :: NDArray) where T<:Real - copy!(dst, copy(src)) + dst end -function copy!(dst :: NDArray, src :: Array{T}) where T<:Real +copy!(dst::Array{<:Real}, src::NDArray) = copy!(dst, copy(src)) + +function copy!(dst::NDArray{T}, src::Array{<:Real}) where {T} @assert dst.writable @assert size(dst) == size(src) - src = convert(Array{eltype(dst)}, src) # this might involve copying + src = convert(Array{T}, src) # this might involve copying @mxcall(:MXNDArraySyncCopyFromCPU, (MX_handle, Ptr{Void}, Csize_t), dst.handle, pointer(src), length(src)) - return dst + dst end -function copy_ignore_shape!(dst :: NDArray, src :: Array{T}) where T<:Real +function copy_ignore_shape!(dst::NDArray{T}, src::Array{<:Real}) where {T} @assert dst.writable @assert length(dst) == length(src) - src = convert(Array{eltype(dst)}, src) # this might involve copying + src = convert(Array{T}, src) # this might involve copying @mxcall(:MXNDArraySyncCopyFromCPU, (MX_handle, Ptr{Void}, Csize_t), dst.handle, pointer(src), length(src)) - return dst + dst end From 7ff7a387e38f2b0a13aca25749c1422c74d2e747 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Sun, 19 Nov 2017 16:14:14 +0800 Subject: [PATCH 05/37] ndarray: refine convert --- src/ndarray.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/ndarray.jl b/src/ndarray.jl index 2abdb15da..c580f9ef6 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -509,13 +509,12 @@ copy(x::Array{T}, ctx::Context) where {T<:DType} = copy!(empty(T, size(x), ctx), x) """ - convert(::Type{Array{T}}, arr :: NDArray) + convert(::Type{Array{<:Real}}, x::NDArray) -Convert an `NDArray` into a Julia `Array` of specific type. Data will be copied. +Convert an `NDArray` into a Julia `Array` of specific type. +Data will be copied. """ -function convert(t::Type{Array{T}}, arr :: NDArray) where T<:Real - convert(t, copy(arr)) -end +convert(T::Type{Array{<:Real}}, x::NDArray) = convert(T, copy(x)) """ deepcopy(arr::NDArray) From 886b3ad3e29d52cf996af3861088c7c7025b97af Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Sun, 19 Nov 2017 16:25:07 +0800 Subject: [PATCH 06/37] ndarray: refine add_to! --- src/ndarray.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ndarray.jl b/src/ndarray.jl index c580f9ef6..6a108f032 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -570,13 +570,13 @@ end Add a bunch of arguments into `dst`. Inplace updating. """ -function add_to!(dst::NDArray, args::NDArrayOrReal...) +function add_to!(dst::NDArray{T}, args::NDArrayOrReal...) 
where T @assert dst.writable for arg in args if isa(arg, Real) - _plus_scalar(dst, scalar=convert(eltype(dst), arg), out=dst) + _plus_scalar(dst, scalar = convert(T, arg), out = dst) else - _plus(dst, arg, out=dst) + _plus(dst, arg, out = dst) end end return dst From a7e4d46df05cea92a80e24e2ee6fa038a29febf2 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Sun, 19 Nov 2017 16:34:43 +0800 Subject: [PATCH 07/37] ndarray: refine sub_from! --- src/ndarray.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ndarray.jl b/src/ndarray.jl index 6a108f032..3288ff708 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -603,12 +603,12 @@ broadcast_(::typeof(+), x::Real, y::NDArray) = x + y Subtract a bunch of arguments from `dst`. Inplace updating. """ -function sub_from!(dst::NDArray, arg::NDArrayOrReal) +function sub_from!(dst::NDArray{T}, arg::NDArrayOrReal) where T @assert dst.writable if isa(arg, Real) - _minus_scalar(dst, scalar=convert(eltype(dst), arg), out=dst) + _minus_scalar(dst, scalar = convert(T, arg), out = dst) else - _minus(dst, arg, out=dst) + _minus(dst, arg, out = dst) end end From 3929a599a49816226fc07da925be9e42cc345521 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Sun, 19 Nov 2017 16:34:59 +0800 Subject: [PATCH 08/37] ndarray: refine mul_to! --- src/ndarray.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ndarray.jl b/src/ndarray.jl index 3288ff708..1ad823e56 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -635,12 +635,12 @@ broadcast_(::typeof(-), x::Real, y::NDArray) = x - y Elementwise multiplication into `dst` of either a scalar or an `NDArray` of the same shape. Inplace updating. """ -function mul_to!(dst::NDArray, arg::NDArrayOrReal) +function mul_to!(dst::NDArray{T}, arg::NDArrayOrReal) where T @assert dst.writable if isa(arg, Real) - _mul_scalar(dst, scalar=convert(eltype(dst), arg), out=dst) + _mul_scalar(dst, scalar = convert(T, arg), out = dst) else - _mul(dst, arg, out=dst) + _mul(dst, arg, out = dst) end end From ca1d69b63cef308cd42809b31da7d8b4f01a6901 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Sun, 19 Nov 2017 16:35:20 +0800 Subject: [PATCH 09/37] ndarray: refine div_from! --- src/ndarray.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/ndarray.jl b/src/ndarray.jl index 1ad823e56..0e87a9c68 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -663,23 +663,19 @@ broadcast_(::typeof(*), x::Real, y::NDArray) = y .* x Matrix (2D NDArray) multiplication. """ -function *(x::NDArray, y::NDArray) - @assert ndims(x) == 2 - @assert ndims(y) == 2 - dot(x, y) -end +*(x::NDArray{T,2}, y::NDArray{S,2}) where {T,S} = dot(x, y) """ div_from!(dst::NDArray, arg::NDArrayOrReal) Elementwise divide a scalar or an `NDArray` of the same shape from `dst`. Inplace updating. """ -function div_from!(dst::NDArray, arg::NDArrayOrReal) +function div_from!(dst::NDArray{T}, arg::NDArrayOrReal) where {T} @assert dst.writable if isa(arg, Real) - _div_scalar(dst, scalar=convert(eltype(dst), arg), out=dst) + _div_scalar(dst, scalar = convert(T, arg), out = dst) else - _div(dst, arg, out=dst) + _div(dst, arg, out = dst) end end From 2df8321927a64f522d6c557919c539307cbeceaf Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Sun, 19 Nov 2017 16:35:39 +0800 Subject: [PATCH 10/37] ndarray: refine rdiv_from! 
--- src/ndarray.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ndarray.jl b/src/ndarray.jl index 0e87a9c68..73fd243ac 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -684,9 +684,9 @@ end Elementwise divide a scalar by an `NDArray`. Inplace updating. """ -function rdiv_from!(x::Real, y::NDArray) +function rdiv_from!(x::Real, y::NDArray{T}) where {T} @assert y.writable - _rdiv_scalar(y, scalar=convert(eltype(y), x), out=y) + _rdiv_scalar(y, scalar = convert(T, x), out = y) end import Base: / From 8256b26e72b972c1e6333e0b124112332a160e91 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Sun, 19 Nov 2017 16:50:56 +0800 Subject: [PATCH 11/37] ndarray: refine _wait_to_read/_wait_to_write --- src/ndarray.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/ndarray.jl b/src/ndarray.jl index 73fd243ac..bdfc578e4 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -869,12 +869,11 @@ function pointer(arr :: NDArray) @mxcall(:MXNDArrayGetData, (MX_handle, Ref{Ptr{Void}}), arr, pdata) return convert(Ptr{eltype(arr)}, pdata[]) end -function _wait_to_read(arr :: NDArray) + +@inline _wait_to_read(arr :: NDArray) = @mxcall(:MXNDArrayWaitToRead, (MX_handle,), arr) -end -function _wait_to_write(arr :: NDArray) +@inline _wait_to_write(arr :: NDArray) = @mxcall(:MXNDArrayWaitToWrite, (MX_handle,), arr) -end """ try_get_shared(arr; sync=:nop) From 3c68e9ab3bd9d7822116efd469853c9b56e5facf Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Sun, 19 Nov 2017 20:46:47 +0800 Subject: [PATCH 12/37] ndarray: refine is_shared --- src/ndarray.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/ndarray.jl b/src/ndarray.jl index bdfc578e4..b4b7ea747 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -892,19 +892,19 @@ Try to create a Julia array by sharing the data with the underlying `NDArray`. On CPU, invoke `_wait_to_read` if `:read`; invoke `_wait_to_write` if `:write`. """ -function try_get_shared(arr :: NDArray; sync::Symbol=:nop) - if context(arr).device_type == CPU +function try_get_shared(x::NDArray; sync::Symbol=:nop) + if context(x).device_type == CPU # try to do data sharing if sync == :read - _wait_to_read(arr) + _wait_to_read(x) elseif sync == :write - _wait_to_write(arr) + _wait_to_write(x) end - unsafe_wrap(Array, pointer(arr), size(arr)) + unsafe_wrap(Array, pointer(x), size(x)) else # impossible to share, just copying - copy(arr) + copy(x) end end @@ -918,16 +918,16 @@ Test whether `j_arr` is sharing data with `arr`. * `j_arr::Array`: the Julia Array. * `arr::NDArray`: the `NDArray`. """ -is_shared(j_arr :: Array, arr :: NDArray) = false +is_shared(::Array, ::NDArray) = false -function is_shared(j_arr :: Array{T}, arr :: NDArray) where T<:DType +function is_shared(j_arr::Array{T}, arr::NDArray{T}) where {T<:DType} if length(j_arr) != length(arr) return false end if context(arr).device_type != CPU return false end - return pointer(j_arr) == pointer(arr) + pointer(j_arr) == pointer(arr) end """ From bd700a59baccc137d421bf2a7ff71ae05e2a354a Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Sun, 19 Nov 2017 20:57:46 +0800 Subject: [PATCH 13/37] ndarray: refine save --- src/ndarray.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/ndarray.jl b/src/ndarray.jl index b4b7ea747..89a6f9e41 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -976,15 +976,14 @@ with corresponding support (see `load`). 
""" save(filename::String, data::NDArray) = save(filename, [data]) -function save(filename::String, data::Vector{<:NDArray}) +save(filename::String, data::AbstractVector{<:NDArray}) = @mxcall(:MXNDArraySave, (char_p, MX_uint, Ptr{MX_handle}, char_pp), filename, length(data), MX_handle[data...], char_pp(0)) -end -function save(filename::String, data::Dict{Symbol,<:NDArray}) - names = [k for k in keys(data)] - arrays = MX_handle[data[k] for k in names] - names = String[string(k) for k in names] +function save(filename::String, data::Dict{Symbol}) + names = keys(data) + arrays = MX_handle.(collect(values(data))) + names = String.(collect(names)) @mxcall(:MXNDArraySave, (char_p, MX_uint, Ptr{MX_handle}, char_pp), filename, length(names), arrays, names) From a648fa34dce669c11c0ffe43f8d94bf43331d054 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Sun, 19 Nov 2017 21:17:31 +0800 Subject: [PATCH 14/37] ndarray: refine dot --- src/ndarray.jl | 4 ++-- test/unittest/ndarray.jl | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/ndarray.jl b/src/ndarray.jl index 89a6f9e41..cd7e37f2f 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -1004,7 +1004,7 @@ function _autoimport(name::Symbol) end macro _remap(sig::Expr, imp::Expr) - fname = sig.args[1] + fname = (sig.head == :call) ? sig.args[1] : sig.args[1].args[1] # case of `where` opname = string(imp.args[1]) import_expr = _autoimport(fname) @@ -1080,7 +1080,7 @@ _mxsig[:reshape] = :(reshape(arr; shape = dim, reverse = !reverse)) @_remap minimum(arr::NDArray, dims) min(arr; axis = 0 .- dims, keepdims = true) # See https://github.com/dmlc/MXNet.jl/issues/55 -@_remap dot(x::NDArray, y::NDArray) dot(y, x) +@_remap dot(x::NDArray{T,N}, y::NDArray{S,N}) where {T,S,N} dot(y, x) # See https://github.com/dmlc/MXNet.jl/pull/123 @_remap transpose(arr::NDArray) transpose(_only2d(arr)) diff --git a/test/unittest/ndarray.jl b/test/unittest/ndarray.jl index 06082e8b7..b4f81eac0 100644 --- a/test/unittest/ndarray.jl +++ b/test/unittest/ndarray.jl @@ -586,6 +586,10 @@ function test_dot() y = mx.zeros(dims2) z = mx.dot(x, y) @test size(z) == (2, 8) + + x = mx.zeros(1, 2) + y = mx.zeros(1, 2, 3) + @test_throws MethodError dot(x, y) end function test_eltype() From 0d7faf13051dd88c265fd28493222f29d2020e8a Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Sun, 19 Nov 2017 21:25:38 +0800 Subject: [PATCH 15/37] ndarray: VecOfNDArray --- src/executor.jl | 1 - src/ndarray.jl | 8 +++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/executor.jl b/src/executor.jl index adc079c9c..40f66d95b 100644 --- a/src/executor.jl +++ b/src/executor.jl @@ -5,7 +5,6 @@ An executor is a realization of a symbolic architecture defined by a `SymbolicNo The actual forward and backward computation specified by the network architecture can be carried out with an executor. """ -const VecOfNDArray = AbstractVector{<:NDArray} mutable struct Executor{A<:VecOfNDArray, B<:VecOfNDArray, G<:AbstractVector{<:Union{Void,NDArray}}, diff --git a/src/ndarray.jl b/src/ndarray.jl index cd7e37f2f..639405375 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -104,7 +104,9 @@ NDArray(handle, writable = true) = NDArray(x::AbstractArray{T}) where {T<:DType} = copy(collect(x), cpu()) NDArray(x::Array{T}) where {T<:DType} = copy(x, cpu()) +# type aliases const NDArrayOrReal = Union{NDArray, Real} +const VecOfNDArray = AbstractVector{<:NDArray} @unfuse NDArray @@ -971,12 +973,12 @@ Save NDarrays to binary file. 
 Filename could be S3 or HDFS address, if `libmxnet` is built
 with corresponding support (see `load`).
 
 * `filename::String`: path to the binary file to write to.
-* `data`: data to save to file. Data can be a`NDArray`, a `Vector{<:NDArray}`,
-  or a `Dict{Symbol, <:NDArray}`.
+* `data`: data to save to file. Data can be a`NDArray`, a `Vector` of `NDArray`,
+  or a `Dict{Symbol}` contains `NDArray`s.
 """
 save(filename::String, data::NDArray) = save(filename, [data])
 
-save(filename::String, data::AbstractVector{<:NDArray}) =
+save(filename::String, data::VecOfNDArray) =
   @mxcall(:MXNDArraySave, (char_p, MX_uint, Ptr{MX_handle}, char_pp),
           filename, length(data), MX_handle[data...], char_pp(0))

From bf21a7fbbb427dc095133dac80cd168d37d7baf7 Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Sun, 19 Nov 2017 21:36:43 +0800
Subject: [PATCH 16/37] executor: refine backward

---
 src/executor.jl | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/src/executor.jl b/src/executor.jl
index 40f66d95b..94333ab90 100644
--- a/src/executor.jl
+++ b/src/executor.jl
@@ -175,17 +175,11 @@ function forward(self::Executor; is_train::Bool = false, kwargs...)
   self.outputs
 end
 
-function backward(self :: Executor)
-  backward(self, NDArray[])
-end
-function backward(self :: Executor, out_grad :: NDArray)
-  backward(self, [out_grad])
-end
-function backward(self :: Executor, out_grads :: Vector{<:NDArray})
-  out_grads = MX_handle[out_grads...]
-  @mxcall(:MXExecutorBackward, (MX_handle, MX_uint, Ptr{MX_handle}), self, length(out_grads), out_grads)
-end
-
+backward(x::Executor) = backward(x, NDArray[])
+backward(x::Executor, out_grad::NDArray) = backward(x, [out_grad])
+backward(x::Executor, out_grads::VecOfNDArray) =
+  @mxcall(:MXExecutorBackward, (MX_handle, MX_uint, Ptr{MX_handle}),
+          x, length(out_grads), MX_handle[out_grads...])
 
 function copy_params_from(self::Executor, arg_params::Dict{Symbol,<:NDArray},
                           aux_params::Union{Void,Dict{Symbol,<:NDArray}}=nothing;

From 6bf408d02912f0da62518962819f7dc0e756579f Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Sun, 19 Nov 2017 22:54:52 +0800
Subject: [PATCH 17/37] ndarray: refine empty

---
 src/ndarray.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ndarray.jl b/src/ndarray.jl
index 639405375..ff559be96 100644
--- a/src/ndarray.jl
+++ b/src/ndarray.jl
@@ -173,7 +173,7 @@ Create an `NDArray` with similar shape, data type,
 and context with the given one.
 Note that the returned `NDArray` is uninitialized.
 """
-Base.similar(x::NDArray) = empty(eltype(x), size(x), context(x))
+Base.similar(x::NDArray{T}) where {T} = empty(T, size(x), context(x))

From a53ea28b4209391b60d9126df103376f11e4140b Mon Sep 17 00:00:00 2001
From: Iblis Lin
Date: Mon, 20 Nov 2017 16:13:31 +0800
Subject: [PATCH 18/37] executor: refine bind

---
 src/executor.jl | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/src/executor.jl b/src/executor.jl
index 94333ab90..c51ada60a 100644
--- a/src/executor.jl
+++ b/src/executor.jl
@@ -118,16 +118,12 @@ function bind(self :: SymbolicNode, ctx :: Context, args :: Union{Vector{<:NDArr
   executor = Executor(MX_ExecutorHandle(ref_hdr[]), self, args, args_grad, aux_states)
 end
-function bind(self :: SymbolicNode; kwargs...)
+
+function bind(x::SymbolicNode; context::Context = cpu(), kwargs...)
kwargs = Dict(kwargs) @assert(haskey(kwargs, :args), "Must specify args") args = pop!(kwargs, :args) - if haskey(kwargs, :context) - context = pop!(kwargs, :context) - else - context = cpu() - end - bind(self, context, args; kwargs...) + bind(x, ctx, args; kwargs...) end function simple_bind(self :: SymbolicNode, ctx :: Context; From 42dbee4347f9c7aef55cbfedd086a5bd3422b414 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Wed, 22 Nov 2017 00:30:40 +0800 Subject: [PATCH 19/37] refine callback --- src/callback.jl | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/callback.jl b/src/callback.jl index 289fdd066..16e44fdc6 100644 --- a/src/callback.jl +++ b/src/callback.jl @@ -48,7 +48,7 @@ end See also [`every_n_epoch`](@ref) and [`speedometer`](@ref). """ -function every_n_batch(callback :: Function, n :: Int; call_on_0 :: Bool = false) +function every_n_batch(callback::Function, n::Int; call_on_0::Bool = false) BatchCallback(n, call_on_0, callback) end function (cb :: BatchCallback)(state :: OptimizationState) @@ -62,7 +62,7 @@ function (cb :: BatchCallback)(state :: OptimizationState) end """ - speedometer(; frequency=50) + speedometer(;frequency=50) Create an `AbstractBatchCallback` that measure the training speed (number of samples processed per second) every k mini-batches. @@ -71,9 +71,9 @@ Create an `AbstractBatchCallback` that measure the training speed * `frequency::Int`: keyword argument, default 50. The frequency (number of min-batches) to measure and report the speed. """ -function speedometer(;frequency::Int=50) +function speedometer(;frequency::Int = 50) cl_tic = 0 - every_n_batch(frequency, call_on_0=true) do state :: OptimizationState + every_n_batch(frequency, call_on_0 = true) do state::OptimizationState if state.curr_batch == 0 # reset timer cl_tic = time() @@ -104,10 +104,11 @@ A convenient function to construct a callback that runs every `n` full data-pass See also [`every_n_batch`](@ref). """ -function every_n_epoch(callback :: Function, n :: Int; call_on_0 :: Bool = false) +every_n_epoch(callback::Function, n::Int; call_on_0::Bool = false) = EpochCallback(n, call_on_0, callback) -end -function (cb :: EpochCallback)(model :: Any, state :: OptimizationState, metric :: Vector{Tuple{Base.Symbol, T}}) where T<:Real + +function (cb::EpochCallback)(model::Any, state::OptimizationState, + metric::Vector{Tuple{Symbol, T}}) where T<:Real if state.curr_epoch == 0 if cb.call_on_0 cb.callback(model, state, metric) @@ -124,15 +125,17 @@ Create an `AbstractEpochCallback` that save checkpoints of the model to disk. The checkpoints can be loaded back later on. # Arguments -* `prefix::AbstractString`: the prefix of the filenames to save the model. The model - architecture will be saved to prefix-symbol.json, while the weights will be saved - to prefix-0012.params, for example, for the 12-th epoch. -* `frequency::Int`: keyword argument, default 1. The frequency (measured in epochs) to - save checkpoints. +* `prefix::AbstractString`: the prefix of the filenames to save the model. + The model architecture will be saved to prefix-symbol.json, + while the weights will be saved to prefix-0012.params, + for example, for the 12-th epoch. +* `frequency::Int`: keyword argument, default is 1. + The frequency (measured in epochs) to save checkpoints. * `save_epoch_0::Bool`: keyword argument, default false. Whether we should save a - checkpoint for epoch 0 (model initialized but not seen any data yet). 
+ checkpoint for epoch 0 (model initialized but not seen any data yet). """ -function do_checkpoint(prefix::AbstractString; frequency::Int=1, save_epoch_0=false) +function do_checkpoint(prefix::AbstractString; + frequency::Int = 1, save_epoch_0::Bool = false) mkpath(dirname(prefix)) every_n_epoch(frequency, call_on_0=save_epoch_0) do model, state, metric save_checkpoint(model, prefix, state) From fdeabbad0e39c7440096000d2d719129b3b9b691 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Wed, 22 Nov 2017 00:34:25 +0800 Subject: [PATCH 20/37] refine executor --- src/executor.jl | 57 ++++++++++++++++++++++++++----------------------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/src/executor.jl b/src/executor.jl index c51ada60a..93d59209b 100644 --- a/src/executor.jl +++ b/src/executor.jl @@ -9,8 +9,8 @@ mutable struct Executor{A<:VecOfNDArray, B<:VecOfNDArray, G<:AbstractVector{<:Union{Void,NDArray}}, O<:VecOfNDArray, - D<:Dict{Symbol,<:NDArray}, - E<:Dict{Symbol,<:NDArray}} + D<:Dict{Symbol}, + E<:Dict{Symbol}} handle :: MX_ExecutorHandle symbol :: SymbolicNode arg_arrays :: A @@ -20,40 +20,43 @@ mutable struct Executor{A<:VecOfNDArray, arg_dict :: D aux_dict :: E end -function Executor(hdr :: MX_ExecutorHandle, symbol :: SymbolicNode, - arg_arrays :: Vector{<:NDArray}, grad_arrays :: Vector{<:Union{Void,NDArray}}, - aux_arrays :: Vector{<:NDArray}) + +function Executor(hdl::MX_ExecutorHandle, + sym::SymbolicNode, + arg_arrays::VecOfNDArray, + grad_arrays::AbstractVector, + aux_arrays::VecOfNDArray) # get output arrays ref_size = Ref{MX_uint}(0) - ref_hdrs = Ref{Ptr{MX_handle}}(0) + ref_hdls = Ref{Ptr{MX_handle}}(C_NULL) @mxcall(:MXExecutorOutputs, (MX_handle, Ref{MX_uint}, Ref{Ptr{MX_handle}}), - hdr, ref_size, ref_hdrs) - out_hdrs = unsafe_wrap(Array, ref_hdrs[], ref_size[]) + hdl, ref_size, ref_hdls) + out_hdrs = unsafe_wrap(Array, ref_hdls[], ref_size[]) out_arrays = [NDArray(MX_NDArrayHandle(x)) for x in out_hdrs] - arg_names = list_arguments(symbol) + arg_names = list_arguments(sym) @assert(length(arg_names) == length(unique(arg_names)), "Duplicated names in arguments: $arg_names") arg_dict = Dict(zip(arg_names, arg_arrays)) - aux_names = list_auxiliary_states(symbol) + aux_names = list_auxiliary_states(sym) @assert(length(aux_names) == length(unique(aux_names)), "Duplicated names in auxiliary states: $aux_names") aux_dict = Dict(zip(aux_names, aux_arrays)) - Executor(hdr, symbol, arg_arrays, grad_arrays, aux_arrays, out_arrays, arg_dict, aux_dict) + Executor(hdl, sym, arg_arrays, grad_arrays, aux_arrays, out_arrays, arg_dict, aux_dict) end -function Base.unsafe_convert(::Type{MX_handle}, obj::Executor) +Base.unsafe_convert(::Type{MX_handle}, obj::Executor) = Base.unsafe_convert(MX_handle, obj.handle) -end Base.convert(t::Type{MX_handle}, obj::Executor) = Base.unsafe_convert(t, obj) Base.cconvert(t::Type{MX_handle}, obj::Executor) = Base.unsafe_convert(t, obj) -function _get_ndarray_inputs(arg_key::AbstractString, args::Vector{<:NDArray}, +function _get_ndarray_inputs(arg_key::AbstractString, args::VecOfNDArray, arg_names::Vector{Symbol}, allow_missing::Bool) @assert(length(args) == length(arg_names), "Length of $arg_key does not match number of arguments") return (MX_handle[args...], args) end -function _get_ndarray_inputs(arg_key::AbstractString, args::Dict{Symbol,<:NDArray}, + +function _get_ndarray_inputs(arg_key::AbstractString, args::Dict{Symbol}, arg_names::Vector{Symbol}, allow_missing::Bool) args_vec = map(arg_names) do name arr = get(args, name, nothing) @@ 
-82,16 +85,16 @@ Create an `Executor` by binding a `SymbolicNode` to concrete `NDArray`. * `ctx::Context`: the context on which the computation should run. * `args`: either a list of `NDArray` or a dictionary of name-array pairs. Concrete arrays for all the inputs in the network architecture. The inputs typically include - network parameters (weights, bias, filters, etc.), data and labels. See [`list_arguments`](@ref) - and [`infer_shape`](@ref). -* `args_grad`: -* `aux_states`: -* `grad_req`: + network parameters (weights, bias, filters, etc.), data and labels. + See [`list_arguments`](@ref) and [`infer_shape`](@ref). +* `args_grad`: a `Vector` of `NDArray` or a `Dict` contains `NDArray` +* `aux_states`: a `Vector` of `NDArray` or a `Dict` contains `NDArray` +* `grad_req`: single value, a `Vector` of `GRAD_REQ` or a `Dict{Symbol,GRAD_REQ}` """ -function bind(self :: SymbolicNode, ctx :: Context, args :: Union{Vector{<:NDArray},Dict{Symbol,<:NDArray}}; - args_grad :: Union{Vector{<:NDArray},Dict{Symbol,<:NDArray}} = Dict{Symbol,NDArray}(), - aux_states :: Union{Vector{<:NDArray},Dict{Symbol,<:NDArray}} = Dict{Symbol,NDArray}(), - grad_req :: Union{GRAD_REQ,Vector{GRAD_REQ},Dict{Base.Symbol,GRAD_REQ}} = GRAD_WRITE) +function bind(self::SymbolicNode, ctx::Context, args; + args_grad = Dict{Symbol,NDArray}(), + aux_states = Dict{Symbol,NDArray}(), + grad_req = GRAD_WRITE) arg_names = list_arguments(self) @@ -123,11 +126,11 @@ function bind(x::SymbolicNode; context::Context = cpu(), kwargs...) kwargs = Dict(kwargs) @assert(haskey(kwargs, :args), "Must specify args") args = pop!(kwargs, :args) - bind(x, ctx, args; kwargs...) + bind(x, context, args; kwargs...) end -function simple_bind(self :: SymbolicNode, ctx :: Context; - grad_req :: Union{GRAD_REQ, Dict{Symbol, GRAD_REQ}}=GRAD_WRITE, +function simple_bind(self::SymbolicNode, ctx::Context; + grad_req::Union{GRAD_REQ,Dict{Symbol,GRAD_REQ}} = GRAD_WRITE, kwargs...) arg_shapes, out_shapes, aux_shapes = infer_shape(self; kwargs...) 
@assert(!isa(arg_shapes, Void), "Information not enough to perform complete shape inference") From 8bde97e12181c964bc323bbbde4e5c3397e76585 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Wed, 22 Nov 2017 00:34:43 +0800 Subject: [PATCH 21/37] refine kvstore --- src/kvstore.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kvstore.jl b/src/kvstore.jl index 2e424e6fb..bdd0902b1 100644 --- a/src/kvstore.jl +++ b/src/kvstore.jl @@ -6,7 +6,7 @@ mutable struct KVStore KVStore(hdr :: MX_KVStoreHandle) = new(hdr, Ptr{Void}(0)) end -function KVStore(kv_type::Base.Symbol = :local) +function KVStore(kv_type::Symbol = :local) #@assert(kv_type ∈ [:local]) # TODO: update with allowed types ref_hdr = Ref{MX_handle}(0) From 658caeb504f6b4b1f9d16c150e78fa89a23db320 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Wed, 22 Nov 2017 00:36:26 +0800 Subject: [PATCH 22/37] refine model.jl --- src/model.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/model.jl b/src/model.jl index 8845f65ad..b0004d800 100644 --- a/src/model.jl +++ b/src/model.jl @@ -18,13 +18,13 @@ mutable struct FeedForward <: AbstractModel arch :: SymbolicNode ctx :: Vector{Context} - arg_params :: Dict{Symbol,<:NDArray} - aux_params :: Dict{Symbol,<:NDArray} + arg_params :: Dict{Symbol} + aux_params :: Dict{Symbol} pred_exec :: Union{Executor, Void} # leave the rest fields undefined - FeedForward(arch :: SymbolicNode, ctx :: Vector{Context}) = new(arch, ctx) + FeedForward(arch::SymbolicNode, ctx::Vector{Context}) = new(arch, ctx) end """ @@ -264,7 +264,7 @@ function _init_model(self :: FeedForward, data :: AbstractDataProvider, initiali init_model(self, initializer; overwrite=overwrite, [provide_data(data)..., provide_label(data)...]...) end -function _create_kvstore(kv_type :: Base.Symbol, num_device :: Int, arg_params :: Dict{Base.Symbol,NDArray}, verbosity :: Int) +function _create_kvstore(kv_type::Symbol, num_device::Int, arg_params::Dict{Symbol}, verbosity :: Int) if num_device == 1 && !ismatch(r"dist", string(kv_type)) return nothing else @@ -286,7 +286,7 @@ end n_epoch :: Int = 10, eval_data :: Union{Void, AbstractDataProvider} = nothing, eval_metric :: AbstractEvalMetric = Accuracy(), - kvstore :: Union{Base.Symbol, KVStore} = :local, + kvstore :: Union{Symbol, KVStore} = :local, force_init :: Bool = false, callbacks :: Vector{AbstractCallback} = AbstractCallback[], verbosity :: Int = 3 @@ -294,7 +294,7 @@ end function _invoke_callbacks(self::FeedForward, callbacks::Vector{AbstractCallback}, state::OptimizationState, type_filter::Type; - metric::Vector{Tuple{Base.Symbol, T}} = Vector{Tuple{Base.Symbol, Real}}()) where T<:Real + metric::Vector{Tuple{Symbol, T}} = Vector{Tuple{Symbol, Real}}()) where T<:Real map(callbacks) do cb if isa(cb, type_filter) if type_filter == AbstractEpochCallback @@ -332,7 +332,7 @@ Train the `model` on `data` with the `optimizer`. calculated on the validation set. * `kvstore`: keyword argument, default `:local`. The key-value store used to synchronize gradients and parameters when multiple devices are used for training. - :type kvstore: `KVStore` or `Base.Symbol` + :type kvstore: `KVStore` or `Symbol` * `initializer::AbstractInitializer`: keyword argument, default `UniformInitializer(0.01)`. * `force_init::Bool`: keyword argument, default false. 
By default, the random initialization using the provided `initializer` will be skipped if the model weights already exists, maybe from a previous @@ -362,7 +362,7 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra # setup kvstore kvstore = opts.kvstore - if isa(kvstore, Base.Symbol) + if isa(kvstore, Symbol) opts.verbosity >= 2 && info("Creating KVStore...") kvstore = _create_kvstore(kvstore, length(self.ctx), self.arg_params, opts.verbosity) end @@ -622,7 +622,7 @@ end Load a mx.FeedForward model from the checkpoint *prefix*, *epoch* and optionally provide a context. """ -function load_checkpoint(prefix :: AbstractString, epoch :: Int, ::Type{FeedForward}; context = nothing) +function load_checkpoint(prefix::AbstractString, epoch::Int, ::Type{FeedForward}; context = nothing) arch, arg_params, aux_params = load_checkpoint(prefix, epoch) model = FeedForward(arch, context = context) model.arg_params = arg_params From 6ccb7137d299595ef20037009e69aab815a989fb Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Wed, 22 Nov 2017 00:37:31 +0800 Subject: [PATCH 23/37] fix mnist-test --- src/model.jl | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/model.jl b/src/model.jl index b0004d800..51d7bd5da 100644 --- a/src/model.jl +++ b/src/model.jl @@ -586,20 +586,22 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra nothing end -function save_checkpoint(self :: FeedForward, prefix :: AbstractString, state :: OptimizationState) +save_checkpoint(self::FeedForward, prefix::AbstractString, state::OptimizationState) = save_checkpoint(self.arch, self.arg_params, self.aux_params, prefix, state.curr_epoch) -end -function save_checkpoint(sym :: SymbolicNode, arg_params :: Dict{Base.Symbol, NDArray}, - aux_params :: Dict{Base.Symbol, NDArray}, prefix :: AbstractString, epoch :: Int) + +function save_checkpoint(sym::SymbolicNode, arg_params::Dict{Symbol}, + aux_params::Dict{Symbol}, prefix::AbstractString, epoch::Int) save("$prefix-symbol.json", sym) - save_dict = merge(Dict{Base.Symbol, NDArray}(map((x) -> Symbol("arg:$(x[1])") => x[2], arg_params)), - Dict{Base.Symbol, NDArray}(map((x) -> Symbol("aux:$(x[1])") => x[2], aux_params))) + save_dict = Dict{Symbol, NDArray}(map((x) -> Symbol("arg:$(x[1])") => x[2], arg_params)) + if !isempty(aux_params) + merge!(save_dict, Dict(map((x) -> Symbol("aux:$(x[1])") => x[2], aux_params))) + end save_filename = format("{1}-{2:04d}.params", prefix, epoch) save(save_filename, save_dict) info("Saved checkpoint to '$save_filename'") end -function load_checkpoint(prefix :: AbstractString, epoch :: Int) +function load_checkpoint(prefix::AbstractString, epoch::Int) arch = load("$prefix-symbol.json", SymbolicNode) saved_dict = load(format("{1}-{2:04d}.params", prefix, epoch), NDArray) arg_params = Dict{Base.Symbol, NDArray}() From 12b36a080e62798c3538567817eed4628771e7c2 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Fri, 24 Nov 2017 20:07:08 +0800 Subject: [PATCH 24/37] metrics --- src/metric.jl | 94 +++++++++++++++++++++------------------------------ 1 file changed, 39 insertions(+), 55 deletions(-) diff --git a/src/metric.jl b/src/metric.jl index db38060c9..93af9ef31 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -30,11 +30,12 @@ Update and accumulate metrics. * `labels::Vector{NDArray}`: the labels from the data provider. * `preds::Vector{NDArray}`: the outputs (predictions) of the network. 
""" -function update!(metric :: T, labels :: Vector{NDArray}, preds :: Vector{NDArray}) where T <: AbstractEvalMetric +function update!(metric::T, labels::VecOfNDArray, preds::VecOfNDArray) where T <: AbstractEvalMetric _update!(metric, labels, preds, hasNDArraySupport(metric)) end -function _update!(metric :: T, labels :: Vector{NDArray}, preds :: Vector{NDArray}, :: Val{true}) where T<: AbstractEvalMetric +function _update!(metric::T, labels::VecOfNDArray, preds::VecOfNDArray, + ::Val{true}) where T<: AbstractEvalMetric if length(labels) != length(preds) Base.warn_once( "The number of labels ($(length(labels))) does not correspond to the\ @@ -45,7 +46,8 @@ function _update!(metric :: T, labels :: Vector{NDArray}, preds :: Vector{NDArra end end -function _update!(metric :: T, labels :: Vector{NDArray}, preds :: Vector{NDArray}, :: Val{false}) where T<: AbstractEvalMetric +function _update!(metric::T, labels::VecOfNDArray, preds::VecOfNDArray, + ::Val{false}) where T<: AbstractEvalMetric if length(labels) != length(preds) Base.warn_once( "The number of labels ($(length(labels))) does not correspond to the\ @@ -65,9 +67,7 @@ end Reset the accumulation counter. """ -function reset!(metric :: AbstractEvalMetric) - throw(MethodError(reset!, (typeof(metric),))) -end +reset!(metric::AbstractEvalMetric) = throw(MethodError(reset!, (typeof(metric),))) import Base: get @@ -79,9 +79,7 @@ Get the accumulated metrics. Returns `Vector{Tuple{Base.Symbol, Real}}`, a list of name-value pairs. For example, `[(:accuracy, 0.9)]`. """ -function get(metric :: AbstractEvalMetric) - throw(MethodError(get, (typeof(metric),))) -end +get(metric::AbstractEvalMetric) = throw(MethodError(get, (typeof(metric),))) """ NullMetric() @@ -91,17 +89,11 @@ A metric that calculates nothing. 
Can be used to ignore an output during trainin mutable struct NullMetric <: mx.AbstractEvalMetric end -function update!(metric :: NullMetric, labels :: Vector{NDArray}, preds :: Vector{NDArray}) - return nothing -end +update!(metric::NullMetric, labels::VecOfNDArray, preds::VecOfNDArray) = nothing -function reset!(metric :: NullMetric) - return nothing -end +reset!(metric::NullMetric) = nothing -function get(metric :: NullMetric) - return Tuple{Symbol, Float64}[] -end +get(metric::NullMetric) = Tuple{Symbol, Float64}[] """ MultiMetric(metrics::Vector{AbstractEvalMetric}) @@ -118,21 +110,19 @@ mutable struct MultiMetric <: mx.AbstractEvalMetric metrics :: Vector{mx.AbstractEvalMetric} end -function update!(metric :: MultiMetric, labels :: Vector{NDArray}, preds :: Vector{NDArray}) +function update!(metric :: MultiMetric, labels :: Vector{<:NDArray}, preds :: Vector{<:NDArray}) for m in metric.metrics update!(m, labels, preds) end - return nothing + nothing end function reset!(metric :: MultiMetric) map(reset!, metric.metrics) - return nothing + nothing end -function get(metric :: MultiMetric) - mapreduce(get, append!, metric.metrics) -end +get(metric :: MultiMetric) = mapreduce(get, append!, metric.metrics) """ SeqMetric(metrics::Vector{AbstractEvalMetric}) @@ -150,23 +140,21 @@ mutable struct SeqMetric <: mx.AbstractEvalMetric metrics :: Vector{mx.AbstractEvalMetric} end -function update!(metric :: SeqMetric, labels :: Vector{NDArray}, preds :: Vector{NDArray}) +function update!(metric::SeqMetric, labels::VecOfNDArray, preds::VecOfNDArray) @assert length(metric.metrics) == length(labels) @assert length(metric.metrics) == length(preds) for (m, l, p) in zip(metric.metrics, labels, preds) update!(m, [l], [p]) end - return nothing + nothing end -function reset!(metric :: SeqMetric) +function reset!(metric::SeqMetric) map(reset!, metric.metrics) - return nothing + nothing end -function get(metric :: SeqMetric) - mapreduce(get, append!, metric.metrics) -end +get(metric::SeqMetric) = mapreduce(get, append!, metric.metrics) """ Accuracy @@ -185,7 +173,7 @@ end hasNDArraySupport(::Accuracy) = Val{false}() -function _update_single_output(metric :: Accuracy, label :: Array, pred :: Array) +function _update_single_output(metric::Accuracy, label::Array, pred::Array) # Samples are stored in the last dimension @assert size(label, ndims(label)) == size(pred, ndims(pred)) @@ -217,9 +205,7 @@ function _update_single_output(metric :: Accuracy, label :: Array, pred :: Array end end -function get(metric :: Accuracy) - return [(:accuracy, metric.acc_sum / metric.n_sample)] -end +get(metric::Accuracy) = [(:accuracy, metric.acc_sum / metric.n_sample)] function reset!(metric :: Accuracy) metric.acc_sum = 0.0 @@ -235,31 +221,34 @@ Calculates the mean squared error regression loss. Requires that label and prediction have the same shape. """ -mutable struct MSE <: AbstractEvalMetric - mse_sum :: Vector{NDArray} +mutable struct MSE{T<:NDArray} <: AbstractEvalMetric + mse_sum :: Vector{T} n_sample :: Int - MSE() = new(Vector{NDArray}(), 0) + MSE{T}() where {T<:NDArray} = new(Vector{T}(), 0) end +MSE() = MSE{NDArray}() # backward compat? 
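# A rough sketch (not part of this patch): because `update!` dispatches on the
# `hasNDArraySupport` trait shown earlier in this file, a user-defined metric
# only needs the struct, the trait, `_update_single_output`, `get` and `reset!`.
# The names below (MeanAbsError, :MAE) are hypothetical.
mutable struct MeanAbsError <: mx.AbstractEvalMetric
  abs_sum  :: Float64
  n_sample :: Int
  MeanAbsError() = new(0.0, 0)
end

# Opt out of NDArray support: the Val{false} branch of _update! then copies
# each output to a Julia Array before calling _update_single_output.
mx.hasNDArraySupport(::MeanAbsError) = Val{false}()

function mx._update_single_output(m::MeanAbsError, label::Array, pred::Array)
  m.abs_sum  += sum(abs.(label .- pred))
  m.n_sample += length(label)
end

mx.get(m::MeanAbsError)    = [(:MAE, m.abs_sum / m.n_sample)]
mx.reset!(m::MeanAbsError) = (m.abs_sum = 0.0; m.n_sample = 0)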
+ hasNDArraySupport(::MSE) = Val{true}() -function _update_single_output(metric :: MSE, label :: NDArray, pred :: NDArray) +function _update_single_output(metric::MSE, label::NDArray{T,N}, + pred::NDArray{T,N}) where {T,N} @assert size(label) == size(pred) metric.n_sample += length(label) - mse_sum = mx.sum(mx._PowerScalar(label - pred,scalar=2)) + mse_sum = mx.sum((label .- pred).^2) push!(metric.mse_sum, mse_sum) - return nothing + nothing end -function get(metric :: MSE) +function get(metric::MSE) # Delay copy until last possible moment mse_sum = mapreduce(nda->copy(nda)[1], +, 0.0, metric.mse_sum) - return [(:MSE, mse_sum / metric.n_sample)] + [(:MSE, mse_sum / metric.n_sample)] end -function reset!(metric :: MSE) - metric.mse_sum = Vector{NDArray}() +function reset!(metric::MSE{T}) where T + metric.mse_sum = Vector{T}() metric.n_sample = 0 end @@ -319,7 +308,7 @@ end hasNDArraySupport(::NMSE) = Val{false}() -function _update_single_output(metric :: NMSE, label :: Array, pred :: Array) +function _update_single_output(metric::NMSE, label::Array, pred::Array) n_sample = size(pred)[end] metric.n_sample += n_sample @@ -332,11 +321,9 @@ function _update_single_output(metric :: NMSE, label :: Array, pred :: Array) end end -function get(metric :: NMSE) - return [(:NMSE, metric.nmse_sum / metric.n_sample)] -end +get(metric::NMSE) = [(:NMSE, metric.nmse_sum / metric.n_sample)] -function reset!(metric :: NMSE) +function reset!(metric::NMSE) metric.nmse_sum = 0.0 metric.n_sample = 0 end @@ -357,11 +344,9 @@ mutable struct ACE <: AbstractEvalMetric ACE(eps=1.0e-8) = new(0.0, 0, eps) end -function get(metric :: ACE) - return [(:ACE, - metric.ace_sum / metric.n_sample)] -end +get(metric::ACE) = [(:ACE, - metric.ace_sum / metric.n_sample)] -function reset!(metric :: ACE) +function reset!(metric::ACE) metric.ace_sum = 0.0 metric.n_sample = 0 end @@ -474,4 +459,3 @@ function _update_single_output(metric :: MultiACE, label :: Array{T}, pred :: Ar error("Can't handle prediction with dimensions $(ndims(pred)).") end end - From 4956d0c2dc990392ba7b3b76879e7c393c94c0be Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Fri, 24 Nov 2017 21:51:37 +0800 Subject: [PATCH 25/37] io --- src/io.jl | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/src/io.jl b/src/io.jl index 53101590d..d77b6ff7d 100644 --- a/src/io.jl +++ b/src/io.jl @@ -113,14 +113,19 @@ function get_label end A basic subclass of `AbstractDataBatch`, that implement the interface by accessing member fields. """ -mutable struct DataBatch{V<:AbstractVector{<:NDArray}} <: AbstractDataBatch - data :: V - label :: V +mutable struct DataBatch{T,S,N,M} <: AbstractDataBatch + data :: Vector{NDArray{T,N}} + label :: Vector{NDArray{S,M}} count :: Int end -count_samples(batch :: DataBatch) = batch.count -get_data(::Provider, batch :: DataBatch) where {Provider<:AbstractDataProvider} = batch.data -get_label(::Provider, batch :: DataBatch) where {Provider<:AbstractDataProvider} = batch.label + +count_samples(batch::DataBatch) = batch.count + +get_data(::Provider, batch::DataBatch) where {Provider<:AbstractDataProvider} = + batch.data + +get_label(::Provider, batch::DataBatch) where {Provider<:AbstractDataProvider} = + batch.label """ SlicedNDArray @@ -254,10 +259,10 @@ Construct a data provider from `NDArray` or Julia Arrays. TODO: remove `data_padding` and `label_padding`, and implement rollover that copies the last or first several training samples to feed the padding. 
""" -mutable struct ArrayDataProvider <: AbstractDataProvider - data_arrays +mutable struct ArrayDataProvider{T,N,S,M} <: AbstractDataProvider + data_arrays :: Vector{Array{T,N}} data_names :: Vector{Symbol} - label_arrays + label_arrays :: Vector{Array{S,M}} label_names :: Vector{Symbol} batch_size :: Int sample_count :: Int @@ -265,8 +270,8 @@ mutable struct ArrayDataProvider <: AbstractDataProvider data_padding :: MX_float label_padding :: MX_float - data_batch - label_batch + data_batch :: Vector{NDArray{T,N}} + label_batch :: Vector{NDArray{S,M}} end # Julia's type system is sometimes very frustrating. You cannot specify a function @@ -458,8 +463,8 @@ function _get_label(handle :: MX_DataIterHandle) end function MXDataProvider(handle :: MX_DataIterHandle; - data_name :: Base.Symbol=:data, - label_name :: Union{Base.Symbol,Void}=:softmax_label, + data_name :: Symbol = :data, + label_name :: Union{Symbol,Void} = :softmax_label, kwargs...) # for convenience, we ignore the rest keyword arguments # init iterator, load the first batch and get shapes @assert(_iter_next(handle), "Failed to load the first batch in MXDataProvider") From 16bd7f2fa0023f76d18c4c3b583975f1c2e51a6b Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Fri, 24 Nov 2017 21:52:23 +0800 Subject: [PATCH 26/37] model --- src/model.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/model.jl b/src/model.jl index 51d7bd5da..da1b437b2 100644 --- a/src/model.jl +++ b/src/model.jl @@ -604,8 +604,8 @@ end function load_checkpoint(prefix::AbstractString, epoch::Int) arch = load("$prefix-symbol.json", SymbolicNode) saved_dict = load(format("{1}-{2:04d}.params", prefix, epoch), NDArray) - arg_params = Dict{Base.Symbol, NDArray}() - aux_params = Dict{Base.Symbol, NDArray}() + arg_params = Dict{Symbol,Any}() + aux_params = Dict{Symbol,Any}() for (k,v) in saved_dict tp, name = split(string(k), ':') name = Symbol(name) From c07a6a1d302240adc2744db48bade2a8b5615d68 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Fri, 24 Nov 2017 21:52:39 +0800 Subject: [PATCH 27/37] typo --- src/metric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/metric.jl b/src/metric.jl index 93af9ef31..e3556dea2 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -14,7 +14,7 @@ abstract type AbstractEvalMetric end hasNDArraySupport(metric) -> Val{true/false} Trait for `_update_single_output` should return `Val{true}() if metric can handle `NDArray` -directly and `Val{false}()i` if requires `Array`. Metric that work with NDArrays can be +directly and `Val{false}()` if requires `Array`. Metric that work with NDArrays can be async, while native Julia arrays require that we copy the output of the network, which is a blocking operation. """ From 4177125facd2641f3e797a90cb37e76f94a32f5b Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Fri, 24 Nov 2017 21:52:54 +0800 Subject: [PATCH 28/37] executor --- src/executor.jl | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/src/executor.jl b/src/executor.jl index 93d59209b..5f0996558 100644 --- a/src/executor.jl +++ b/src/executor.jl @@ -5,26 +5,19 @@ An executor is a realization of a symbolic architecture defined by a `SymbolicNo The actual forward and backward computation specified by the network architecture can be carried out with an executor. 
""" -mutable struct Executor{A<:VecOfNDArray, - B<:VecOfNDArray, - G<:AbstractVector{<:Union{Void,NDArray}}, - O<:VecOfNDArray, - D<:Dict{Symbol}, - E<:Dict{Symbol}} +mutable struct Executor handle :: MX_ExecutorHandle symbol :: SymbolicNode - arg_arrays :: A - grad_arrays :: G - aux_arrays :: B - outputs :: O - arg_dict :: D - aux_dict :: E + arg_arrays :: VecOfNDArray + grad_arrays :: Vector{Union{Void,<:NDArray}} + aux_arrays :: VecOfNDArray + outputs :: VecOfNDArray + arg_dict :: Dict{Symbol} + aux_dict :: Dict{Symbol} end -function Executor(hdl::MX_ExecutorHandle, - sym::SymbolicNode, - arg_arrays::VecOfNDArray, - grad_arrays::AbstractVector, +function Executor(hdl::MX_ExecutorHandle, sym::SymbolicNode, + arg_arrays::VecOfNDArray, grad_arrays::AbstractVector, aux_arrays::VecOfNDArray) # get output arrays ref_size = Ref{MX_uint}(0) From 4bbef45fca1ca451667f9fffeafd7fe95b1e0fa9 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Fri, 24 Nov 2017 22:16:14 +0800 Subject: [PATCH 29/37] io --- src/io.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/io.jl b/src/io.jl index d77b6ff7d..ac37a8d84 100644 --- a/src/io.jl +++ b/src/io.jl @@ -259,10 +259,10 @@ Construct a data provider from `NDArray` or Julia Arrays. TODO: remove `data_padding` and `label_padding`, and implement rollover that copies the last or first several training samples to feed the padding. """ -mutable struct ArrayDataProvider{T,N,S,M} <: AbstractDataProvider +mutable struct ArrayDataProvider{T,N} <: AbstractDataProvider data_arrays :: Vector{Array{T,N}} data_names :: Vector{Symbol} - label_arrays :: Vector{Array{S,M}} + label_arrays label_names :: Vector{Symbol} batch_size :: Int sample_count :: Int @@ -270,8 +270,8 @@ mutable struct ArrayDataProvider{T,N,S,M} <: AbstractDataProvider data_padding :: MX_float label_padding :: MX_float - data_batch :: Vector{NDArray{T,N}} - label_batch :: Vector{NDArray{S,M}} + data_batch + label_batch end # Julia's type system is sometimes very frustrating. You cannot specify a function From f867d9282b320de2dc4da0747bbd486c7646dd1c Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Sat, 25 Nov 2017 15:26:11 +0800 Subject: [PATCH 30/37] MSE --- src/metric.jl | 12 ++++++------ src/ndarray.jl | 3 --- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/metric.jl b/src/metric.jl index e3556dea2..3998af8ef 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -221,14 +221,14 @@ Calculates the mean squared error regression loss. Requires that label and prediction have the same shape. """ -mutable struct MSE{T<:NDArray} <: AbstractEvalMetric - mse_sum :: Vector{T} +mutable struct MSE{N} <: AbstractEvalMetric + mse_sum :: Vector{NDArray{MX_float,N}} n_sample :: Int - MSE{T}() where {T<:NDArray} = new(Vector{T}(), 0) + MSE{N}() where {N} = new(Vector{NDArray{MX_float,N}}(), 0) end -MSE() = MSE{NDArray}() # backward compat? +MSE() = MSE{1}() # backward compat? 
hasNDArraySupport(::MSE) = Val{true}() @@ -247,8 +247,8 @@ function get(metric::MSE) [(:MSE, mse_sum / metric.n_sample)] end -function reset!(metric::MSE{T}) where T - metric.mse_sum = Vector{T}() +function reset!(metric::MSE{N}) where N + metric.mse_sum = Vector{NDArray{Float32,N}}() metric.n_sample = 0 end diff --git a/src/ndarray.jl b/src/ndarray.jl index ff559be96..a415de8f2 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -101,9 +101,6 @@ NDArray(x::Array{T}) where {T<:DType} = copy(x, cpu()) NDArray(handle, writable = true) = NDArray{eltype(handle), ndims(handle)}(handle, writable) -NDArray(x::AbstractArray{T}) where {T<:DType} = copy(collect(x), cpu()) -NDArray(x::Array{T}) where {T<:DType} = copy(x, cpu()) - # type aliases const NDArrayOrReal = Union{NDArray, Real} const VecOfNDArray = AbstractVector{<:NDArray} From dce2a4e0efd3440dcd5720071af058498ceb6e45 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Sat, 25 Nov 2017 15:29:00 +0800 Subject: [PATCH 31/37] style --- src/model.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/model.jl b/src/model.jl index da1b437b2..99c5b534d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -97,8 +97,8 @@ function init_model(self :: FeedForward, initializer :: AbstractInitializer; ove self.aux_params = Dict{Symbol, NDArray}() end - arg_params = Dict{Symbol, NDArray}() - aux_params = Dict{Symbol, NDArray}() + arg_params = Dict{Symbol,NDArray}() + aux_params = Dict{Symbol,NDArray}() for (name, shape) in filter(x -> in(x[1],param_names), zip(arg_names, arg_shapes)) if haskey(self.arg_params, name) From 0fce4f675be96d146e1757a949986a308b6eef72 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Tue, 28 Nov 2017 01:33:24 +0800 Subject: [PATCH 32/37] refine copy_params_from --- src/executor.jl | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/executor.jl b/src/executor.jl index 5f0996558..cd4a9256e 100644 --- a/src/executor.jl +++ b/src/executor.jl @@ -173,9 +173,9 @@ backward(x::Executor, out_grads::VecOfNDArray) = @mxcall(:MXExecutorBackward, (MX_handle, MX_uint, Ptr{MX_handle}), x, length(out_grads), MX_handle[out_grads...]) -function copy_params_from(self::Executor, arg_params::Dict{Symbol,<:NDArray}, - aux_params::Union{Void,Dict{Symbol,<:NDArray}}=nothing; - allow_extra_params::Bool=false) +function copy_params_from(self::Executor, arg_params::Dict{Symbol}, + aux_params::Dict{Symbol} = Dict{Symbol,Any}(); + allow_extra_params::Bool = false) for (name, array) in arg_params if haskey(self.arg_dict, name) copy!(self.arg_dict[name], array) @@ -184,13 +184,11 @@ function copy_params_from(self::Executor, arg_params::Dict{Symbol,<:NDArray}, end end - if !isa(aux_params, Void) - for (name, array) in aux_params - if haskey(self.aux_dict, name) - copy!(self.aux_dict[name], array) - else - @assert(allow_extra_params, "Extra auxiliary state $name not recognized") - end + for (name, array) in aux_params + if haskey(self.aux_dict, name) + copy!(self.aux_dict[name], array) + else + @assert(allow_extra_params, "Extra auxiliary state $name not recognized") end end end From 8a4381015116c724070d5134b05f50c4ef8d374d Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Wed, 29 Nov 2017 12:04:04 +0800 Subject: [PATCH 33/37] io: style stuff --- src/io.jl | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/io.jl b/src/io.jl index ac37a8d84..6c6250436 100644 --- a/src/io.jl +++ b/src/io.jl @@ -357,7 +357,7 @@ function ArrayDataProvider(data::Any, label::Any; batch_size::Int=0, 
shuffle::Bo ArrayDataProvider(data_arrays, data_names, label_arrays, label_names, batch_size, sample_count, shuffle, MX_float(data_padding), MX_float(label_padding), - data_batch, label_batch) + data_batch, label_batch) end function provide_data(provider::ArrayDataProvider) @@ -374,9 +374,7 @@ struct ArrayDataProviderState <: AbstractDataProviderState curr_idx :: Int end -function Base.eltype(provider :: ArrayDataProvider) - ArrayDataProviderState -end +Base.eltype(provider :: ArrayDataProvider) = ArrayDataProviderState function Base.start(provider :: ArrayDataProvider) if provider.shuffle @@ -389,9 +387,8 @@ function Base.start(provider :: ArrayDataProvider) return ArrayDataProviderState(1) end -function Base.done(provider::ArrayDataProvider, state :: ArrayDataProviderState) - return state.curr_idx > provider.sample_count -end +Base.done(provider::ArrayDataProvider, state::ArrayDataProviderState) = + state.curr_idx > provider.sample_count struct ArrayDataBatch <: AbstractDataBatch idx :: UnitRange{Int} @@ -433,8 +430,8 @@ a list of built-in data iterators. """ mutable struct MXDataProvider <: AbstractDataProvider handle :: MX_DataIterHandle - data_shape :: Vector{Tuple{Symbol, Tuple}} - label_shape:: Vector{Tuple{Symbol, Tuple}} + data_shape :: Vector{Tuple{Symbol,Tuple}} + label_shape:: Vector{Tuple{Symbol,Tuple}} batch_size :: Int # those two a auxiliary variables to help avoid calling reset From cbd47b6a5aab8db6b653a0d3efb0b1bde502a0e5 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Wed, 29 Nov 2017 12:04:42 +0800 Subject: [PATCH 34/37] ndarray: fix _remap --- src/ndarray.jl | 2 +- src/util.jl | 14 ++++++++++++++ test/unittest/util.jl | 25 +++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 test/unittest/util.jl diff --git a/src/ndarray.jl b/src/ndarray.jl index 2fda482d4..1a2ffa280 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -1036,7 +1036,7 @@ macro _remap(sig::Expr, imp::Expr) ndhlds = Expr(:vect, map(x -> :($(x).handle), ndin)...) # handler for `func!` which has side effect on first argument. - T, n_output, hdls_ref, retexpr = _outexpr(fname, sig.args[2].args[1]) + T, n_output, hdls_ref, retexpr = _outexpr(fname, _firstarg(sig)) func_body = quote op_handle = _get_cached_libmx_op_handle($opname) diff --git a/src/util.jl b/src/util.jl index 6877200d8..b0f91c824 100644 --- a/src/util.jl +++ b/src/util.jl @@ -202,3 +202,17 @@ function _sig_checker() end end + +""" +Get first position argument from function sig +""" +function _firstarg(sig::Expr) + if sig.head ∈ (:where, :(::)) + _firstarg(sig.args[1]) + elseif sig.head == :call + i = (sig.args[2] isa Expr && sig.args[2].head == :parameters) ? 
3 : 2 + _firstarg(sig.args[i]) + end +end + +_firstarg(s::Symbol) = s diff --git a/test/unittest/util.jl b/test/unittest/util.jl new file mode 100644 index 000000000..d27b509bd --- /dev/null +++ b/test/unittest/util.jl @@ -0,0 +1,25 @@ +module TestUtil + +using Base.Test + +using MXNet + + +function test_firstarg() + info("Util::_firstarg") + @test mx._firstarg(:(f(x, y))) == :x + @test mx._firstarg(:(f(x::mx.NDArray, y))) == :x + @test mx._firstarg(:(f(x::mx.NDArray, y::mx.NDArray))) == :x + @test mx._firstarg(:(f(x::Int, y::mx.NDArray))) == :x + @test mx._firstarg(:(f(x::Int, y::mx.NDArray; other = 42))) == :x + @test mx._firstarg(:(f(x::mx.NDArray{T}, y) where {T})) == :x + @test mx._firstarg(:(f(x::mx.NDArray{T,N}, y) where {T,N})) == :x + @test mx._firstarg(:(f(x::mx.NDArray{T,N} where {T,N}, y))) == :x +end # function test_firstarg + + +@testset "Util Test" begin + test_firstarg() +end # @testset "Util" + +end # module TestUtil From fe8c25154209113a49ca96001b6aeb85e9c86605 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Wed, 29 Nov 2017 12:20:15 +0800 Subject: [PATCH 35/37] io: style stuff --- src/io.jl | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/io.jl b/src/io.jl index 6c6250436..597ea8a90 100644 --- a/src/io.jl +++ b/src/io.jl @@ -205,7 +205,7 @@ import Base.get Returns the corresponding data array corresponding to that name. """ -function get(provider :: AbstractDataProvider, batch :: AbstractDataBatch, name :: Base.Symbol) +function get(provider::AbstractDataProvider, batch::AbstractDataBatch, name::Symbol) for (idx, (k, s)) in enumerate(provide_data(provider)) if name == k return get_data(provider, batch)[idx] @@ -230,7 +230,7 @@ when you need to perform real-time augmentation of the data. instance after modifying its fields. """ -eachbatch(provider :: AbstractDataProvider) = provider +eachbatch(provider::AbstractDataProvider) = provider """ ArrayDataProvider @@ -280,10 +280,14 @@ end # results, about the parametric type in the Pair{T1,T2} type, thus does not match the # generic Pair type. In general, Int <: Number but Vector{Int} <: Vector{Number} is not # true. So let us just use Any here... 
-function ArrayDataProvider(data::Any; batch_size::Int=0, shuffle::Bool=false, data_padding::Real=0, label_padding::Real=0) - ArrayDataProvider(data, [], batch_size=batch_size, shuffle=shuffle, data_padding=data_padding, label_padding=label_padding) +function ArrayDataProvider(data; batch_size::Int = 0, shuffle::Bool = false, + data_padding::Real = 0, label_padding::Real = 0) + ArrayDataProvider(data, [], batch_size = batch_size, shuffle = shuffle, + data_padding = data_padding, label_padding = label_padding) end -function ArrayDataProvider(data::Any, label::Any; batch_size::Int=0, shuffle::Bool=false, data_padding::Real=0, label_padding::Real=0) + +function ArrayDataProvider(data, label; batch_size::Int = 0, shuffle::Bool = false, + data_padding::Real = 0, label_padding::Real = 0) asarr(arr :: Array{T}) where {T} = convert(Array{MX_float}, arr) asarr(arr :: NDArray) = copy(arr) @@ -360,13 +364,11 @@ function ArrayDataProvider(data::Any, label::Any; batch_size::Int=0, shuffle::Bo data_batch, label_batch) end -function provide_data(provider::ArrayDataProvider) - return collect(zip(provider.data_names, map(size, provider.data_batch))) -end +provide_data(provider::ArrayDataProvider) = + collect(zip(provider.data_names, map(size, provider.data_batch))) -function provide_label(provider::ArrayDataProvider) - return collect(zip(provider.label_names, map(size, provider.label_batch))) -end +provide_label(provider::ArrayDataProvider) = + collect(zip(provider.label_names, map(size, provider.label_batch))) get_batch_size(provider::ArrayDataProvider) = provider.batch_size From f98a132d7435831ad68ec17f30f962e72a028ee5 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Wed, 29 Nov 2017 12:29:15 +0800 Subject: [PATCH 36/37] kvstore --- src/kvstore.jl | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/kvstore.jl b/src/kvstore.jl index bdd0902b1..fa4768cce 100644 --- a/src/kvstore.jl +++ b/src/kvstore.jl @@ -3,7 +3,7 @@ mutable struct KVStore updater_c :: Ptr{Void} updater :: Function - KVStore(hdr :: MX_KVStoreHandle) = new(hdr, Ptr{Void}(0)) + KVStore(hdr::MX_KVStoreHandle) = new(hdr, Ptr{Void}(0)) end function KVStore(kv_type::Symbol = :local) @@ -31,16 +31,15 @@ function _flatten_kvlist(keys :: Vector{Int}, vals :: Vector{<:Vector{<:NDArray} return (keys_flt, vals_flt) end -function init!(self :: KVStore, key :: Int, val :: NDArray) - init!(self, [key], [val]) -end -function init!(self :: KVStore, key :: Int, vals :: Vector{<:NDArray}) - init!(self, Base.ones(Int, length(vals))*key, vals) -end -function init!(self :: KVStore, keys :: Vector{Int}, vals :: Vector{Vector{<:NDArray}}) +init!(self::KVStore, key::Int, val::NDArray) = init!(self, [key], [val]) + +init!(self::KVStore, key::Int, vals::Vector{<:NDArray}) = + init!(self, Base.ones(Int, length(vals)) * key, vals) + +init!(self::KVStore, keys::Vector{Int}, vals::Vector{<:Vector{<:NDArray}}) = init!(self, _flatten_kvlist(keys, vals)...) -end -function init!(self :: KVStore, keys :: Vector{Int}, vals :: Vector{<:NDArray}) + +function init!(self::KVStore, keys::Vector{Int}, vals::Vector{<:NDArray}) @assert length(keys) == length(vals) keys = Cint[keys...] vals = MX_handle[vals...] 
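For reference, a minimal usage sketch of the `init!` methods refactored above
(single-process `:local` store; the integer keys and shapes are arbitrary, and
the existing `push!`/`pull!` entry points are assumed unchanged by this series):

    kv = mx.KVStore(:local)
    mx.init!(kv, 3, mx.zeros(2, 3))                          # single key/value
    mx.init!(kv, [4, 5], [mx.ones(2, 3), mx.ones(2, 3)])     # vector form
    mx.push!(kv, 3, mx.ones(2, 3) * 8)
    out = mx.empty(2, 3)
    mx.pull!(kv, 3, out)
    copy(out)    # 2x3 Array filled with 8.0 on a single device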
From 15712558367b7c144600fa53b7166496c72363f4 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Wed, 29 Nov 2017 16:49:30 +0800 Subject: [PATCH 37/37] model: style stuff --- src/model.jl | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/model.jl b/src/model.jl index 886c92a8d..06b7a2cf8 100644 --- a/src/model.jl +++ b/src/model.jl @@ -21,7 +21,7 @@ mutable struct FeedForward <: AbstractModel arg_params :: Dict{Symbol} aux_params :: Dict{Symbol} - pred_exec :: Union{Executor, Void} + pred_exec :: Union{Executor,Void} # leave the rest fields undefined FeedForward(arch::SymbolicNode, ctx::Vector{Context}) = new(arch, ctx) @@ -33,7 +33,7 @@ Get a split of `batch_size` into `n_split` pieces for data parallelization. Retu of length `n_split`, with each entry a `UnitRange{Int}` indicating the slice index for that piece. """ -function _split_inputs(batch_size :: Int, n_split :: Int) +function _split_inputs(batch_size::Int, n_split::Int) @assert(batch_size >= n_split) per_split = floor(Int, batch_size / n_split) counts = Base.zeros(Int, n_split)+per_split @@ -73,7 +73,7 @@ weights. * `input_shapes`: the shape of all data and label inputs to this model, given as keyword arguments. For example, `data=(28,28,1,100), label=(100,)`. """ -function init_model(self :: FeedForward, initializer :: AbstractInitializer; overwrite::Bool=false, input_shapes...) +function init_model(self::FeedForward, initializer::AbstractInitializer; overwrite::Bool=false, input_shapes...) # all arg names, including data, label, and parameters arg_names = list_arguments(self.arch) @@ -138,7 +138,7 @@ function init_model(self :: FeedForward, initializer :: AbstractInitializer; ove return (arg_names, param_names, aux_names) end -function _setup_predictor(self :: FeedForward, overwrite :: Bool=false; verbosity :: Integer = 1, data_shapes...) +function _setup_predictor(self::FeedForward, overwrite::Bool=false; verbosity::Integer = 1, data_shapes...) if !isdefined(self, :pred_exec) || isa(self.pred_exec, Void) || overwrite if !isdefined(self, :arg_params) || !isdefined(self, :aux_params) @assert(false, "Model weights not defined, please init or train the model, or load from file") @@ -202,12 +202,12 @@ end See also [`train`](@ref), [`fit`](@ref), [`init_model`](@ref), and [`load_checkpoint`](@ref) """ -function predict(callback :: Function, self :: FeedForward, data :: AbstractDataProvider; - overwrite :: Bool = true, verbosity :: Integer = 1) +function predict(callback::Function, self::FeedForward, data::AbstractDataProvider; + overwrite::Bool = true, verbosity::Integer = 1) predict(self, data; overwrite = overwrite, callback=callback, verbosity = verbosity) end -function predict(self :: FeedForward, data :: AbstractDataProvider; - overwrite::Bool=true, callback::Union{Function,Void}=nothing, verbosity :: Integer = 1) +function predict(self::FeedForward, data::AbstractDataProvider; + overwrite::Bool = true, callback::Union{Function,Void}=nothing, verbosity::Integer = 1) data_shapes = provide_data(data) data_names = [x[1] for x in data_shapes] _setup_predictor(self, overwrite; verbosity = verbosity, data_shapes...) @@ -255,11 +255,13 @@ function predict(self :: FeedForward, data :: AbstractDataProvider; return output_arrays end -function _init_model(self :: FeedForward, data :: AbstractDataProvider, initializer :: AbstractInitializer, overwrite :: Bool) - init_model(self, initializer; overwrite=overwrite, [provide_data(data)..., provide_label(data)...]...) 
+function _init_model(self::FeedForward, data::AbstractDataProvider, + initializer::AbstractInitializer, overwrite::Bool) + init_model(self, initializer; overwrite=overwrite, + [provide_data(data)..., provide_label(data)...]...) end -function _create_kvstore(kv_type::Symbol, num_device::Int, arg_params::Dict{Symbol}, verbosity :: Int) +function _create_kvstore(kv_type::Symbol, num_device::Int, arg_params::Dict{Symbol}, verbosity::Int) if num_device == 1 && !ismatch(r"dist", string(kv_type)) return nothing else @@ -289,7 +291,7 @@ end function _invoke_callbacks(self::FeedForward, callbacks::Vector{AbstractCallback}, state::OptimizationState, type_filter::Type; - metric::Vector{Tuple{Symbol, T}} = Vector{Tuple{Symbol, Real}}()) where T<:Real + metric::Vector{Tuple{Symbol,T}} = Vector{Tuple{Symbol,Real}}()) where T<:Real map(callbacks) do cb if isa(cb, type_filter) if type_filter == AbstractEpochCallback @@ -342,7 +344,8 @@ Train the `model` on `data` with the `optimizer`. - `2`: Print one time messages and a message at the start of each epoch - `3`: Print a summary of the training and validation accuracy for each epoch """ -function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: AbstractDataProvider; kwargs...) +function fit(self::FeedForward, optimizer::AbstractOptimizer, data::AbstractDataProvider; + kwargs...) opts = TrainingOptions(; kwargs...) opts.verbosity >= 1 && info("Start training on $(self.ctx)") @@ -379,7 +382,7 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra freeze_idx = filter(i -> in(param_names[i], freeze_names), 1:length(param_names)) # Setup grad_req as a dictionary - grad_req = Dict{Symbol, GRAD_REQ}() + grad_req = Dict{Symbol,GRAD_REQ}() for param in param_names if in(param, freeze_names) grad_req[param] = GRAD_NOP @@ -627,8 +630,8 @@ function load_checkpoint(prefix::AbstractString, epoch::Int, ::Type{FeedForward} return model end -function load_checkpoint(self :: FeedForward, prefix :: AbstractString, epoch :: Int; - overwrite :: Bool = true, allow_different_arch :: Bool = false) +function load_checkpoint(self::FeedForward, prefix::AbstractString, epoch::Int; + overwrite::Bool = true, allow_different_arch::Bool = false) if isdefined(self, :arg_params) && isdefined(self, :aux_params) && !overwrite info("model weights already exists, skip loading... (call with overwrite=true if needed)") return self