diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index 7f1281b544b3..e7c853219936 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -845,7 +845,7 @@ def test_randint(): 'high': 3, 'shape' : (500, 500), 'dtype' : dtype, - 'ctx' : mx.context.cpu() + 'ctx' : mx.context.current_context() } mx.random.seed(128) ret1 = mx.nd.random.randint(**params).asnumpy() @@ -856,12 +856,12 @@ def test_randint(): @with_seed() def test_randint_extremes(): - a = mx.nd.random.randint(dtype='int64', low=50000000, high=50000010, ctx=mx.context.cpu()) + a = mx.nd.random.randint(dtype='int64', low=50000000, high=50000010, ctx=mx.context.current_context()) assert a>=50000000 and a<=50000010 @with_seed() def test_randint_generator(): - ctx = mx.context.cpu() + ctx = mx.context.current_context() for dtype in ['int32', 'int64']: for low, high in [(50000000, 50001000),(-50000000,-9900),(-500,199),(-2147483647,2147483647)]: scale = high - low