diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 589503727ce1..9060d6d86dfe 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -525,8 +525,8 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): handle = ExecutorHandle() check_call(_LIB.MXExecutorBind(self.handle, - mx_uint(ctx.device_mask), - mx_uint(ctx.device_id), + ctx.device_typeid, + ctx.device_id, len(args), args_handle, args_grad_handle,