Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Julia: deprecate mx.empty, replace it with UndefInitializer constructor #13934

Merged
merged 1 commit into from
Jan 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions julia/NEWS.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# v0.4.0 (#TBD)
# v1.5.0 (#TBD)

* Following material from `mx` module got exported (#TBD):
* `NDArray`
* `clip()`
* `clip!()`
* `context()`
* `empty()`
* `expand_dims()`
* `@inplace`
* `σ()`
Expand Down Expand Up @@ -113,6 +112,16 @@
3.0
```

* `mx.empty` is deprecated and replaced by `UndefInitializer` constructor. (#TBD)

E.g.
```julia
julia> NDArray(undef, 2, 5)
2×5 NDArray{Float32,2} @ CPU0:
-21260.344f0 1.674986f19 0.00016893122f0 1.8363f-41 0.0f0
3.0763f-41 1.14321726f27 4.24219f-8 0.0f0 0.0f0
```

* A port of Python's `autograd` for `NDArray` (#274)

* `size(x, dims...)` is supported now. (#TBD)
Expand Down
18 changes: 11 additions & 7 deletions julia/docs/src/user-guide/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,13 @@ operators in Julia directly.

The followings are common ways to create `NDArray` objects:

- `mx.empty(shape[, context])`: create on uninitialized array of a
given shape on a specific device. For example,
`mx.empty(2, 3)`, `mx.((2, 3), mx.gpu(2))`.
- `NDArray(undef, shape...; ctx = context, writable = true)`:
create an uninitialized array of a given shape on a specific device.
For example,
`NDArray(undef, 2, 3)`, `NDArray(undef, 2, 3, ctx = mx.gpu(2))`.
- `NDArray(undef, shape; ctx = context, writable = true)`
- `NDArray{T}(undef, shape...; ctx = context, writable = true)`:
create an uninitialized with the given type `T`.
- `mx.zeros(shape[, context])` and `mx.ones(shape[, context])`:
similar to the Julia's built-in `zeros` and `ones`.
- `mx.copy(jl_arr, context)`: copy the contents of a Julia `Array` to
Expand All @@ -101,11 +105,11 @@ shows a way to set the contents of an `NDArray`.
```@repl
using MXNet
mx.srand(42)
a = mx.empty(2, 3)
a = NDArray(undef, 2, 3)
a[:] = 0.5 # set all elements to a scalar
a[:] = rand(size(a)) # set contents with a Julia Array
copy!(a, rand(size(a))) # set value by copying a Julia Array
b = mx.empty(size(a))
b = NDArray(undef, size(a))
b[:] = a # copying and assignment between NDArrays
```

Expand Down Expand Up @@ -175,7 +179,7 @@ function inplace_op()
grad = mx.ones(SHAPE, CTX)

# pre-allocate temp objects
grad_lr = mx.empty(SHAPE, CTX)
grad_lr = NDArray(undef, SHAPE, ctx = CTX)

for i = 1:N_REP
copy!(grad_lr, grad)
Expand Down Expand Up @@ -234,7 +238,7 @@ shape = (2, 3)
key = 3

mx.init!(kv, key, mx.ones(shape) * 2)
a = mx.empty(shape)
a = NDArray(undef, shape)
mx.pull!(kv, key, a) # pull value into a
a
```
Expand Down
1 change: 0 additions & 1 deletion julia/src/MXNet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ export NDArray,
clip,
clip!,
context,
empty,
expand_dims,
@inplace,
# activation funcs
Expand Down
25 changes: 25 additions & 0 deletions julia/src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,28 @@ import Base: sum, maximum, minimum, prod, cat

import Statistics: mean
@deprecate mean(x::NDArray, dims) mean(x, dims = dims)

# replaced by UndefInitializer
function empty(::Type{T}, dims::NTuple{N,Int}, ctx::Context = cpu()) where {N,T<:DType}
@warn("`mx.empty(T, dims, ctx)` is deprecated, " *
"use `NDArray{T,N}(undef, dims; ctx = ctx)` instead.")
NDArray{T,N}(undef, dims; ctx = ctx)
end

function empty(::Type{T}, dims::Int...) where {T<:DType}
@warn("`mx.empty(T, dims...)` is deprecated, " *
"use `NDArray{T,N}(undef, dims...)` instead.")
NDArray{T,N}(undef, dims...)
end

function empty(dims::NTuple{N,Int}, ctx::Context = cpu()) where N
@warn("`mx.empty(dims, ctx)` is deprecated, " *
"use `NDArray(undef, dims; ctx = ctx)` instead.")
NDArray(undef, dims; ctx = ctx)
end

function empty(dims::Int...)
@warn("`mx.empty(dims...)` is deprecated, " *
"use `NDArray(undef, dims...)` instead.")
NDArray(undef, dims...)
end
2 changes: 1 addition & 1 deletion julia/src/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ function ArrayDataProvider(data, label; batch_size::Int = 0, shuffle::Bool = fal
function gen_batch_nds(arrs :: Vector{Array{MX_float}}, bsize :: Int)
map(arrs) do arr
shape = size(arr)
empty(shape[1:end-1]..., bsize)
NDArray(undef, shape[1:end-1]..., bsize)
end
end

Expand Down
8 changes: 4 additions & 4 deletions julia/src/kvstore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ One can use ``barrier()`` to sync all workers.
julia> kv = KVStore(:local)
mx.KVStore @ local

julia> x = mx.empty(2, 3);
julia> x = NDArray(undef, 2, 3);

julia> init!(kv, 3, x)

Expand Down Expand Up @@ -161,11 +161,11 @@ julia> x
```jldoctest
julia> keys = [4, 5];

julia> init!(kv, keys, [empty(2, 3), empty(2, 3)])
julia> init!(kv, keys, [NDArray(undef, 2, 3), NDArray(undef, 2, 3)])

julia> push!(kv, keys, [x, x])

julia> y, z = empty(2, 3), empty(2, 3);
julia> y, z = NDArray(undef, 2, 3), NDArray(undef, 2, 3);

julia> pull!(kv, keys, [y, z])
```
Expand Down Expand Up @@ -279,7 +279,7 @@ julia> init!(kv, 42, mx.ones(2, 3))

julia> push!(kv, 42, mx.ones(2, 3))

julia> x = empty(2, 3);
julia> x = NDArray(undef, 2, 3);

julia> pull!(kv, 42, x)

Expand Down
8 changes: 4 additions & 4 deletions julia/src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function init_model(self::FeedForward, initializer::AbstractInitializer; overwri
delete!(self.arg_params, name)
end
end
arg_params[name] = empty(shape)
arg_params[name] = NDArray(undef, shape)
end

for (name, shape) in zip(aux_names, aux_shapes)
Expand All @@ -135,7 +135,7 @@ function init_model(self::FeedForward, initializer::AbstractInitializer; overwri
delete!(self.aux_params, name)
end
end
aux_params[name] = empty(shape)
aux_params[name] = NDArray(undef, shape)
end

for (k,v) in arg_params
Expand Down Expand Up @@ -463,8 +463,8 @@ function fit(self::FeedForward, optimizer::AbstractOptimizer, data::AbstractData
# set up output and labels in CPU for evaluation metric
output_shapes = [tuple(size(x)[1:end-1]...,batch_size) for x in train_execs[1].outputs]
cpu_dev = Context(CPU)
cpu_output_arrays = [empty(shape, cpu_dev) for shape in output_shapes]
cpu_label_arrays = [empty(shape, cpu_dev) for (name,shape) in provide_label(data)]
cpu_output_arrays = [NDArray(undef, shape, ctx = cpu_dev) for shape in output_shapes]
cpu_label_arrays = [NDArray(undef, shape, ctx = cpu_dev) for (name,shape) in provide_label(data)]

# invoke callbacks on epoch 0
_invoke_callbacks(self, opts.callbacks, op_state, AbstractEpochCallback)
Expand Down
74 changes: 37 additions & 37 deletions julia/src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,47 @@ mutable struct NDArray{T,N}
handle :: MX_NDArrayHandle
writable :: Bool

NDArray{T,N}(handle, writable = true) where {T,N} = new(handle, writable)
NDArray{T,N}(handle::MX_NDArrayHandle, writable::Bool = true) where {T,N} =
new(handle, writable)
end

# UndefInitializer constructors
NDArray{T,N}(::UndefInitializer, dims::NTuple{N,Integer};
writable = true, ctx::Context = cpu()) where {T,N} =
NDArray{T,N}(_ndarray_alloc(T, dims, ctx, false), writable)
NDArray{T,N}(::UndefInitializer, dims::Vararg{Integer,N}; kw...) where {T,N} =
NDArray{T,N}(undef, dims; kw...)

NDArray{T}(::UndefInitializer, dims::NTuple{N,Integer}; kw...) where {T,N} =
NDArray{T,N}(undef, dims; kw...)
NDArray{T}(::UndefInitializer, dims::Vararg{Integer,N}; kw...) where {T,N} =
NDArray{T,N}(undef, dims; kw...)

NDArray(::UndefInitializer, dims::NTuple{N,Integer}; kw...) where {N} =
NDArray{DEFAULT_DTYPE,N}(undef, dims; kw...)
NDArray(::UndefInitializer, dims::Vararg{Integer,N}; kw...) where {N} =
NDArray{DEFAULT_DTYPE,N}(undef, dims; kw...)

NDArray(x::AbstractArray{<:DType}) = copy(collect(x), cpu())
NDArray(x::Array{<:DType}) = copy(x, cpu())

NDArray(::Type{T}, x::AbstractArray) where {T<:DType} =
copy(convert(AbstractArray{T}, x), cpu())

NDArray(handle, writable = true) =
NDArray{eltype(handle), ndims(handle)}(handle, writable)

# type aliases
const NDArrayOrReal = Union{NDArray,Real}
const VecOfNDArray = AbstractVector{<:NDArray}

Base.unsafe_convert(::Type{MX_handle}, x::NDArray) =
Base.unsafe_convert(MX_handle, x.handle)
Base.convert(T::Type{MX_handle}, x::NDArray) = Base.unsafe_convert(T, x)
Base.cconvert(T::Type{MX_handle}, x::NDArray) = Base.unsafe_convert(T, x)

MX_handle(x::NDArray) = Base.convert(MX_handle, x)

function Base.show(io::IO, x::NDArray)
print(io, "NDArray(")
Base.show(io, try_get_shared(x, sync = :read))
Expand All @@ -139,13 +166,6 @@ function Base.show(io::IO, ::MIME{Symbol("text/plain")}, x::NDArray{T,N}) where
Base.print_array(io, try_get_shared(x, sync = :read))
end

Base.unsafe_convert(::Type{MX_handle}, x::NDArray) =
Base.unsafe_convert(MX_handle, x.handle)
Base.convert(T::Type{MX_handle}, x::NDArray) = Base.unsafe_convert(T, x)
Base.cconvert(T::Type{MX_handle}, x::NDArray) = Base.unsafe_convert(T, x)

MX_handle(x::NDArray) = Base.convert(MX_handle, x)

################################################################################
# NDArray functions exported to the users
################################################################################
Expand All @@ -163,34 +183,14 @@ function context(x::NDArray)
end

"""
empty(DType, dims[, ctx::Context = cpu()])
empty(DType, dims)
empty(DType, dim1, dim2, ...)

Allocate memory for an uninitialized `NDArray` with a specified type.
"""
empty(::Type{T}, dims::NTuple{N,Int}, ctx::Context = cpu()) where {N,T<:DType} =
NDArray{T,N}(_ndarray_alloc(T, dims, ctx, false))
empty(::Type{T}, dims::Int...) where {T<:DType} = empty(T, dims)

"""
empty(dims::Tuple[, ctx::Context = cpu()])
empty(dim1, dim2, ...)

Allocate memory for an uninitialized `NDArray` with specific shape of type Float32.
"""
empty(dims::NTuple{N,Int}, ctx::Context = cpu()) where N =
NDArray(_ndarray_alloc(dims, ctx, false))
empty(dims::Int...) = empty(dims)

"""
similar(x::NDArray)
similar(x::NDArray; writable, ctx)

Create an `NDArray` with similar shape, data type,
and context with the given one.
Note that the returned `NDArray` is uninitialized.
"""
Base.similar(x::NDArray{T}) where {T} = empty(T, size(x), context(x))
Base.similar(x::NDArray{T,N}; writable = x.writable, ctx = context(x)) where {T,N} =
NDArray{T,N}(undef, size(x)...; writable = writable, ctx = ctx)

"""
zeros([DType], dims, [ctx::Context = cpu()])
Expand All @@ -200,7 +200,7 @@ Base.similar(x::NDArray{T}) where {T} = empty(T, size(x), context(x))
Create zero-ed `NDArray` with specific shape and type.
"""
function zeros(::Type{T}, dims::NTuple{N,Int}, ctx::Context = cpu()) where {N,T<:DType}
x = empty(T, dims, ctx)
x = NDArray{T}(undef, dims..., ctx = ctx)
x[:] = zero(T)
x
end
Expand All @@ -222,7 +222,7 @@ Base.zeros(x::NDArray)::typeof(x) = zeros_like(x)
Create an `NDArray` with specific shape & type, and initialize with 1.
"""
function ones(::Type{T}, dims::NTuple{N,Int}, ctx::Context = cpu()) where {N,T<:DType}
arr = empty(T, dims, ctx)
arr = NDArray{T}(undef, dims..., ctx = ctx)
arr[:] = one(T)
arr
end
Expand Down Expand Up @@ -504,10 +504,10 @@ copy(x::NDArray{T,D}, ctx::Context) where {T,D} =

# Create copy: Julia Array -> NDArray in a given context
copy(x::Array{T}, ctx::Context) where {T<:DType} =
copy!(empty(T, size(x), ctx), x)
copy!(NDArray{T}(undef, size(x); ctx = ctx), x)

copy(x::AbstractArray, ctx::Context) =
copy!(empty(eltype(x), size(x), ctx), collect(x))
copy!(NDArray{eltype(x)}(undef, size(x); ctx = ctx), collect(x))

"""
convert(::Type{Array{<:Real}}, x::NDArray)
Expand Down Expand Up @@ -866,8 +866,8 @@ end

Create an `NDArray` filled with the value `x`, like `Base.fill`.
"""
function fill(x, dims::NTuple{N,Integer}, ctx::Context=cpu()) where N
arr = empty(typeof(x), dims, ctx)
function fill(x::T, dims::NTuple{N,Integer}, ctx::Context = cpu()) where {T,N}
arr = NDArray{T}(undef, dims, ctx = ctx)
arr[:] = x
arr
end
Expand Down
10 changes: 5 additions & 5 deletions julia/src/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ Samples are uniformly distributed over the half-open interval [low, high)
(includes low, but excludes high).

```julia
julia> mx.rand!(empty(2, 3))
julia> mx.rand!(NDArray(undef, 2, 3))
2×3 mx.NDArray{Float32,2} @ CPU0:
0.385748 0.839275 0.444536
0.0879585 0.215928 0.104636

julia> mx.rand!(empty(2, 3), low = 1, high = 10)
julia> mx.rand!(NDArray(undef, 2, 3), low = 1, high = 10)
2×3 mx.NDArray{Float32,2} @ CPU0:
6.6385 4.18888 2.07505
8.97283 2.5636 1.95586
Expand Down Expand Up @@ -56,8 +56,8 @@ julia> mx.rand(2, 2; low = 1, high = 10)
9.81258 3.58068
```
"""
rand(dims::Int...; low = 0, high = 1, context = cpu()) =
rand!(empty(dims, context), low = low, high = high)
rand(dims::Integer...; low = 0, high = 1, context = cpu()) =
rand!(NDArray(undef, dims, ctx = context), low = low, high = high)

"""
randn!(x::NDArray; μ = 0, σ = 1)
Expand All @@ -73,7 +73,7 @@ randn!(x::NDArray; μ = 0, σ = 1) =
Draw random samples from a normal (Gaussian) distribution.
"""
randn(dims::Int...; μ = 0, σ = 1, context = cpu()) =
randn!(empty(dims, context), μ = μ, σ = σ)
randn!(NDArray(undef, dims, ctx = context), μ = μ, σ = σ)

"""
seed!(seed::Int)
Expand Down
8 changes: 4 additions & 4 deletions julia/test/unittest/bind.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ function test_arithmetic(::Type{T}, uf, gf) where T <: mx.DType
ret = uf(lhs, rhs)
@test mx.list_arguments(ret) == [:lhs, :rhs]

lhs_arr = mx.NDArray(rand(T, shape))
rhs_arr = mx.NDArray(rand(T, shape))
lhs_grad = mx.empty(T, shape)
rhs_grad = mx.empty(T, shape)
lhs_arr = NDArray(rand(T, shape))
rhs_arr = NDArray(rand(T, shape))
lhs_grad = NDArray{T}(undef, shape)
rhs_grad = NDArray{T}(undef, shape)

exec2 = mx.bind(ret, mx.Context(mx.CPU), [lhs_arr, rhs_arr], args_grad=[lhs_grad, rhs_grad])
exec3 = mx.bind(ret, mx.Context(mx.CPU), [lhs_arr, rhs_arr])
Expand Down
4 changes: 2 additions & 2 deletions julia/test/unittest/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ function test_mnist()
n_batch = 0
for batch in mnist_provider
if n_batch == 0
data_array = mx.empty(28,28,1,batch_size)
label_array = mx.empty(batch_size)
data_array = NDArray(undef, 28, 28, 1, batch_size)
label_array = NDArray(undef, batch_size)
# have to use "for i=1:1" to get over the legacy "feature" of using
# [ ] to do concatenation in Julia
data_targets = [[(1:batch_size, data_array)] for i = 1:1]
Expand Down
2 changes: 1 addition & 1 deletion julia/test/unittest/kvstore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function test_single_kv_pair()

kv = init_kv()
mx.push!(kv, 3, mx.ones(SHAPE))
val = mx.empty(SHAPE)
val = NDArray(undef, SHAPE)
mx.pull!(kv, 3, val)
@test maximum(abs.(copy(val) .- 1)) == 0
end
Expand Down
Loading