Skip to content

Commit

Permalink
ndarray: fix add broadcasting and more tests (apache#308)
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin authored and pluskid committed Nov 8, 2017
1 parent 57cc677 commit eea128a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
19 changes: 10 additions & 9 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -541,11 +541,11 @@ macro inplace(stmt)
end

"""
add_to!(dst :: NDArray, args :: Union{Real, NDArray}...)
add_to!(dst::NDArray, args::NDArrayOrReal...)
Add a bunch of arguments into `dst`. Inplace updating.
"""
function add_to!(dst :: NDArray, args :: Union{Real, NDArray}...)
function add_to!(dst::NDArray, args::NDArrayOrReal...)
@assert dst.writable
for arg in args
if isa(arg, Real)
Expand All @@ -567,17 +567,18 @@ Summation. Multiple arguments of either scalar or `NDArray` could be
added together. Note at least the first or second argument needs to be an
`NDArray` to avoid ambiguity of built-in summation.
"""
+(x::NDArray, ys::NDArrayOrReal...) = add_to!(copy(x, context(x)), ys...)
+(x::NDArray, ys::NDArrayOrReal...) = add_to!(copy(x, context(x)), ys...)
+(x::Real, y::NDArray, zs::NDArrayOrReal...) = add_to!(copy(y, context(y)), x, zs...)

broadcast_(::typeof(+), x::NDArray, y::NDArrayOrReal) = x + y
broadcast_(::typeof(+), x::Real, y::NDArray) = x + y

"""
sub_from!(dst :: NDArray, args :: Union{Real, NDArray}...)
sub_from!(dst::NDArray, args::NDArrayOrReal...)
Subtract a bunch of arguments from `dst`. Inplace updating.
"""
function sub_from!(dst :: NDArray, arg :: Union{Real, NDArray})
function sub_from!(dst::NDArray, arg::NDArrayOrReal)
@assert dst.writable
if isa(arg, Real)
_minus_scalar(dst, scalar=convert(eltype(dst), arg), out=dst)
Expand All @@ -604,12 +605,12 @@ broadcast_(::typeof(-), x::NDArray, y::NDArrayOrReal) = x - y
broadcast_(::typeof(-), x::Real, y::NDArray) = x - y

"""
mul_to!(dst :: NDArray, arg :: Union{Real, NDArray})
mul_to!(dst::NDArray, arg::NDArrayOrReal)
Elementwise multiplication into `dst` of either a scalar or an `NDArray` of the same shape.
Inplace updating.
"""
function mul_to!(dst :: NDArray, arg :: Union{Real, NDArray})
function mul_to!(dst::NDArray, arg::NDArrayOrReal)
@assert dst.writable
if isa(arg, Real)
_mul_scalar(dst, scalar=convert(eltype(dst), arg), out=dst)
Expand Down Expand Up @@ -644,11 +645,11 @@ function *(x::NDArray, y::NDArray)
end

"""
div_from!(dst :: NDArray, arg :: Union{Real, NDArray})
div_from!(dst::NDArray, arg::NDArrayOrReal)
Elementwise divide a scalar or an `NDArray` of the same shape from `dst`. Inplace updating.
"""
function div_from!(dst :: NDArray, arg :: Union{Real, NDArray})
function div_from!(dst::NDArray, arg::NDArrayOrReal)
@assert dst.writable
if isa(arg, Real)
_div_scalar(dst, scalar=convert(eltype(dst), arg), out=dst)
Expand Down
6 changes: 6 additions & 0 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ function test_plus()
scalar_large = Float16(1e4)
@test reldiff(t6 + scalar_small, copy(a6 .+ scalar_small)) < 1e-1
@test reldiff(t6 + scalar_large, copy(a6 .+ scalar_large)) < 1e-1

let x = mx.NDArray([1 2; 3 4]), y = mx.NDArray([1 1; 1 1])
@test copy(42 .+ x) == [43 44; 45 46]
@test copy(x .+ 42) == [43 44; 45 46]
@test copy(0 .+ x .+ y .+ 41) == [43 44; 45 46]
end
end

function test_minus()
Expand Down

0 comments on commit eea128a

Please sign in to comment.