Skip to content

Commit

Permalink
julia: implement context.num_gpus (apache#16236)
Browse files Browse the repository at this point in the history
MXNET-1427 #resolve
  • Loading branch information
iblislin authored and larroy committed Sep 28, 2019
1 parent ce948fe commit 709b9f3
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 1 deletion.
3 changes: 2 additions & 1 deletion julia/src/MXNet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ export Executor,
# context.jl
export Context,
cpu,
gpu
gpu,
num_gpus

# model.jl
export AbstractModel,
Expand Down
11 changes: 11 additions & 0 deletions julia/src/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,14 @@ Get a GPU context with a specific id. The K GPUs on a node is typically numbered
* `dev_id::Integer = 0` the GPU device id.
"""
gpu(dev_id::Integer = 0) = Context(GPU, dev_id)

"""
num_gpus()
Query CUDA for the number of GPUs present.
"""
function num_gpus()
n = Ref{Cint}()
@mxcall :MXGetGPUCount (Ref{Cint},) n
n[]
end
34 changes: 34 additions & 0 deletions julia/test/unittest/context.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

module TestContext

using MXNet
using Test

function test_num_gpus()
@info "Context::num_gpus"

@test num_gpus() >= 0
end

@testset "Context Test" begin
test_num_gpus()
end


end # module TestContext
2 changes: 2 additions & 0 deletions python/mxnet/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def num_gpus():
check_call(_LIB.MXGetGPUCount(ctypes.byref(count)))
return count.value


def gpu_memory_info(device_id=0):
"""Query CUDA for the free and total bytes of GPU global memory.
Expand All @@ -300,6 +301,7 @@ def gpu_memory_info(device_id=0):
check_call(_LIB.MXGetGPUMemoryInformation64(dev_id, ctypes.byref(free), ctypes.byref(total)))
return (free.value, total.value)


def current_context():
"""Returns the current context.
Expand Down

0 comments on commit 709b9f3

Please sign in to comment.