diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 926a5e811946..1c24bb0a64f3 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -1054,6 +1054,10 @@ class NDArray { } size_t aux_bytes = shape.Size() * mshadow::mshadow_sizeof(aux_types[i]); if (aux_handles[i].size < aux_bytes) { + // set the context for aux handle to make sure we free it to the right device. + if (aux_handles[i].ctx.dev_type != ctx.dev_type) { + aux_handles[i].ctx = ctx; + } // free storage and alloc again Storage::Get()->Free(aux_handles[i]); aux_handles[i] = Storage::Get()->Alloc(aux_bytes, ctx);