Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
change ctx to include gpu, randint samepl_op.cu typo
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaiBapchya committed Nov 20, 2018
1 parent ce3849d commit 494e416
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/operator/random/sample_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_exponential, SampleExponentialParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_poisson, SamplePoissonParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_negative_binomial, SampleNegBinomialParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_generalized_negative_binomial, SampleGenNegBinomialParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_int, SampleRandIntParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_randint, SampleRandIntParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_uniform_like, SampleUniformLikeParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_normal_like, SampleNormalLikeParam)
MXNET_OPERATOR_REGISTER_SAMPLE_GPU(_random_gamma_like, SampleGammaLikeParam)
Expand Down
6 changes: 3 additions & 3 deletions tests/python/unittest/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,7 @@ def test_randint():
'high': 3,
'shape' : (500, 500),
'dtype' : dtype,
'ctx' : mx.context.cpu()
'ctx' : mx.context.current_context()
}
mx.random.seed(128)
ret1 = mx.nd.random.randint(**params).asnumpy()
Expand All @@ -856,12 +856,12 @@ def test_randint():

@with_seed()
def test_randint_extremes():
a = mx.nd.random.randint(dtype='int64', low=50000000, high=50000010, ctx=mx.context.cpu())
a = mx.nd.random.randint(dtype='int64', low=50000000, high=50000010, ctx=mx.context.current_context())
assert a>=50000000 and a<=50000010

@with_seed()
def test_randint_generator():
ctx = mx.context.cpu()
ctx = mx.context.current_context()
for dtype in ['int32', 'int64']:
for low, high in [(50000000, 50001000),(-50000000,-9900),(-500,199),(-2147483647,2147483647)]:
scale = high - low
Expand Down

0 comments on commit 494e416

Please sign in to comment.