From 0b45794e704e1ad8ef53166a9383e9b02b2dc9c9 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Mon, 3 Aug 2020 22:20:19 +0000 Subject: [PATCH] Fix gpu test --- tests/python/unittest/test_gluon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index da209f473ead..5fa1ca20aabf 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -313,7 +313,7 @@ def hybrid_forward(self, F, x): if compute_before_cast: # Compute before casting to catch bugs where symbol dtype isn't casted correctly GH-18843 net_fp32.initialize() - net_fp32(mx.nd.zeros((1,3,224,224))) + net_fp32(mx.nd.zeros((1,3,224,224)), ctx=ctx) net_fp32.cast('float64') net_fp32.hybridize() data = mx.nd.zeros((1,3,224,224), dtype='float64', ctx=ctx)