Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-1440] julia: porting current_context (#17142)
Browse files Browse the repository at this point in the history
* julia: porting `current_context`

- And introduce new macros for changing default context
  `@context`, `@gpu`, `@cpu`
  • Loading branch information
iblislin authored Dec 24, 2019
1 parent 5aa3a7a commit efc4ad8
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 12 deletions.
6 changes: 5 additions & 1 deletion julia/src/MXNet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ export Context,
cpu,
gpu,
num_gpus,
gpu_memory_info
gpu_memory_info,
current_context,
@context,
@cpu,
@gpu

# model.jl
export AbstractModel,
Expand Down
104 changes: 103 additions & 1 deletion julia/src/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,89 @@ struct Context
Context(dev_type::CONTEXT_TYPE, dev_id::Integer = 0) = new(dev_type, dev_id)
end

const _default_ctx = Ref{Context}(Context(CPU, 0))

Context(dev_type::Integer, dev_id::Integer = 0) =
Context(convert(CONTEXT_TYPE, dev_type), dev_id)

Base.show(io::IO, ctx::Context) =
print(io, "$(ctx.device_type)$(ctx.device_id)")
print(io, lowercase("$(ctx.device_type)$(ctx.device_id)"))

function _with_context(dev_type::Union{Symbol,Expr}, dev_id, e::Expr)
global _default_ctx
quote
ctx = current_context()
ctx′ = Context($(esc(dev_type)), $(esc(dev_id)))
$_default_ctx[] = ctx′
try
return $(esc(e))
finally
$_default_ctx[] = ctx
end
end
end

"""
@context device_type [device_id] expr
Change the default context in the following expression.
# Examples
```jl-repl
julia> mx.@context mx.GPU begin
mx.zeros(2, 3)
end
2×3 NDArray{Float32,2} @ gpu0:
0.0f0 0.0f0 0.0f0
0.0f0 0.0f0 0.0f0
julia> @context mx.GPU mx.zeros(3, 2)
3×2 NDArray{Float32,2} @ gpu0:
0.0f0 0.0f0
0.0f0 0.0f0
0.0f0 0.0f0
```
"""
macro context(dev_type, e::Expr)
_with_context(dev_type, 0, e)
end

macro context(dev_type, dev_id, e::Expr)
_with_context(dev_type, dev_id, e)
end

for dev [:cpu, :gpu]
ctx = QuoteNode(Symbol(uppercase(string(dev))))
docstring = """
@$dev [device_id] expr
A shorthand for `@context mx.GPU`.
# Examples
```jl-repl
julia> mx.@with_gpu mx.zeros(2, 3)
2×3 NDArray{Float32,2} @ gpu0:
0.0f0 0.0f0 0.0f0
0.0f0 0.0f0 0.0f0
```
"""
@eval begin
@doc $docstring ->
macro $dev(e::Expr)
ctx = $ctx
quote
@context $ctx $(esc(e))
end
end

macro $dev(dev_id, e::Expr)
ctx = $ctx
quote
@context $ctx $(esc(dev_id)) $(esc(e))
end
end
end
end # for dev ∈ [:cpu, :gpu]

"""
cpu(dev_id)
Expand Down Expand Up @@ -86,3 +164,27 @@ function gpu_memory_info(dev_id = 0)
@mxcall :MXGetGPUMemoryInformation64 (Cint, Ref{UInt64}, Ref{UInt64}) dev_id free n
free[], n[]
end

"""
current_context()
Return the current context.
By default, `mx.cpu()` is used for all the computations
and it can be overridden by using the `@context` macro.
# Examples
```jl-repl
julia> mx.current_context()
cpu0
julia> mx.@context mx.GPU 1 begin # Context changed in the following code block
mx.current_context()
end
gpu1
julia> mx.current_context()
cpu0
```
"""
current_context() = _default_ctx[]
18 changes: 10 additions & 8 deletions julia/src/ndarray/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,43 +28,45 @@ Base.similar(x::NDArray{T,N}; writable = x.writable, ctx = context(x)) where {T,
NDArray{T,N}(undef, size(x)...; writable = writable, ctx = ctx)

"""
zeros([DType], dims, [ctx::Context = cpu()])
zeros([DType], dims, ctx::Context = current_context())
zeros([DType], dims...)
zeros(x::NDArray)
Create zero-ed `NDArray` with specific shape and type.
"""
function zeros(::Type{T}, dims::NTuple{N,Int}, ctx::Context = cpu()) where {N,T<:DType}
function zeros(::Type{T}, dims::NTuple{N,Int},
ctx::Context = current_context()) where {N,T<:DType}
x = NDArray{T}(undef, dims..., ctx = ctx)
x[:] = zero(T)
x
end

zeros(::Type{T}, dims::Int...) where {T<:DType} = zeros(T, dims)

zeros(dims::NTuple{N,Int}, ctx::Context = cpu()) where N =
zeros(dims::NTuple{N,Int}, ctx::Context = current_context()) where N =
zeros(MX_float, dims, ctx)
zeros(dims::Int...) = zeros(dims)

zeros(x::NDArray)::typeof(x) = zeros_like(x)
Base.zeros(x::NDArray)::typeof(x) = zeros_like(x)

"""
ones([DType], dims, [ctx::Context = cpu()])
ones([DType], dims, ctx::Context = current_context())
ones([DType], dims...)
ones(x::NDArray)
Create an `NDArray` with specific shape & type, and initialize with 1.
"""
function ones(::Type{T}, dims::NTuple{N,Int}, ctx::Context = cpu()) where {N,T<:DType}
function ones(::Type{T}, dims::NTuple{N,Int},
ctx::Context = current_context()) where {N,T<:DType}
arr = NDArray{T}(undef, dims..., ctx = ctx)
arr[:] = one(T)
arr
end

ones(::Type{T}, dims::Int...) where T<:DType = ones(T, dims)

ones(dims::NTuple{N,Int}, ctx::Context = cpu()) where N =
ones(dims::NTuple{N,Int}, ctx::Context = current_context()) where N =
ones(MX_float, dims, ctx)
ones(dims::Int...) = ones(dims)

Expand Down Expand Up @@ -458,12 +460,12 @@ function Base.fill!(arr::NDArray, x)
end

"""
fill(x, dims, ctx=cpu())
fill(x, dims, ctx = current_context())
fill(x, dims...)
Create an `NDArray` filled with the value `x`, like `Base.fill`.
"""
function fill(x::T, dims::NTuple{N,Integer}, ctx::Context = cpu()) where {T,N}
function fill(x::T, dims::NTuple{N,Integer}, ctx::Context = current_context()) where {T,N}
arr = NDArray{T}(undef, dims, ctx = ctx)
arr[:] = x
arr
Expand Down
2 changes: 1 addition & 1 deletion julia/src/ndarray/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ end

# UndefInitializer constructors
NDArray{T,N}(::UndefInitializer, dims::NTuple{N,Integer};
writable = true, ctx::Context = cpu()) where {T,N} =
writable = true, ctx::Context = current_context()) where {T,N} =
NDArray{T,N}(_ndarray_alloc(T, dims, ctx, false), writable)
NDArray{T,N}(::UndefInitializer, dims::Vararg{Integer,N}; kw...) where {T,N} =
NDArray{T,N}(undef, dims; kw...)
Expand Down
77 changes: 77 additions & 0 deletions julia/test/unittest/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,85 @@ function test_num_gpus()
@test num_gpus() >= 0
end

function test_context_macro()
@info "Context::@context"

@context mx.CPU 42 begin
ctx = mx.current_context()
@test ctx.device_type == mx.CPU
@test ctx.device_id == 42

@context mx.GPU 24 begin
ctx = mx.current_context()
@test ctx.device_type == mx.GPU
@test ctx.device_id == 24
end

ctx = mx.current_context()
@test ctx.device_type == mx.CPU
@test ctx.device_id == 42
end

function f()
ctx = mx.current_context()
@test ctx.device_type == mx.GPU
@test ctx.device_id == 123
end

@context mx.GPU 123 begin
f()
end

@context mx.GPU begin
ctx = mx.current_context()
@test ctx.device_type == mx.GPU
@test ctx.device_id == 0
end

@context mx.CPU begin
ctx = mx.current_context()
@test ctx.device_type == mx.CPU
@test ctx.device_id == 0
end

@info "Context::@gpu"
@gpu 123 f()
@gpu begin
ctx = mx.current_context()
@test ctx.device_type == mx.GPU
@test ctx.device_id == 0
end
let n = 321
@gpu n begin
ctx = mx.current_context()
@test ctx.device_type == mx.GPU
@test ctx.device_id == 321
end
end

@info "Context::@cpu"
@cpu 123 begin
ctx = mx.current_context()
@test ctx.device_type == mx.CPU
@test ctx.device_id == 123
end
@cpu begin
ctx = mx.current_context()
@test ctx.device_type == mx.CPU
@test ctx.device_id == 0
end
let n = 321
@cpu n begin
ctx = mx.current_context()
@test ctx.device_type == mx.CPU
@test ctx.device_id == 321
end
end
end

@testset "Context Test" begin
test_num_gpus()
test_context_macro()
end


Expand Down
2 changes: 1 addition & 1 deletion julia/test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1294,7 +1294,7 @@ function test_show()
@test occursin("1×4", str)
@test occursin("NDArray", str)
@test occursin("Int64", str)
@test occursin("CPU", str)
@test occursin("cpu", str)
@test match(r"1\s+2\s+3\s+4", str) != nothing
end

Expand Down

0 comments on commit efc4ad8

Please sign in to comment.