Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix test_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Oct 13, 2019
1 parent 7876e7f commit 7a088a9
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from .ndarray import array
from .symbol import Symbol
from .symbol.numpy import _Symbol as np_symbol
from .util import use_np # pylint: disable=unused-import
from .runtime import Features
from .numpy_extension import get_cuda_compute_capability

def has_tvm_ops():
    """Return True if MXNet is compiled with TVM generated operators.

    On a CPU context the ``TVM_OP`` build feature flag alone decides. On a
    GPU context this additionally requires the device's CUDA compute
    capability to be at least 53 (sm_53), the first architecture with FP16
    support; if the capability cannot be queried, False is returned and a
    message is printed so unit tests skip the TVM-op paths.

    Returns
    -------
    bool
        Whether TVM-generated operators can be used on the current context.
    """
    built_with_tvm_op = _features.is_enabled("TVM_OP")
    ctx = current_context()
    if ctx.device_type == 'gpu':
        try:
            cc = get_cuda_compute_capability(ctx)
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate instead of being swallowed as "no capability".
        except Exception:
            print('Failed to get CUDA compute capability for context {}. The operators '
                  'built with USE_TVM_OP=1 will not be run in unit tests.'.format(ctx))
            return False
        print('Cuda arch compute capability: sm_{}'.format(str(cc)))
        # FP16 requires sm_53 or newer, so gate TVM ops on cc >= 53.
        return built_with_tvm_op and cc >= 53
    return built_with_tvm_op


Expand Down

0 comments on commit 7a088a9

Please sign in to comment.