diff --git a/src/operator/tensor/cast_storage-inl.cuh b/src/operator/tensor/cast_storage-inl.cuh
index dfb09d632ac8..1e24ccf4b3f7 100644
--- a/src/operator/tensor/cast_storage-inl.cuh
+++ b/src/operator/tensor/cast_storage-inl.cuh
@@ -188,6 +188,15 @@ struct FillRspValsKernel {
   }
 };
 
+template<typename xpu, int ndim>
+inline mshadow::Tensor<xpu, ndim, char> AllocateTempDataForCast(const OpContext& op_ctx,
+                                                                const mshadow::Shape<ndim>& shape) {
+  Resource rsc = ResourceManager::Get()->Request(op_ctx.run_ctx.ctx,
+                                                 ResourceRequest(ResourceRequest::kTempSpace));
+  mshadow::Stream<xpu> *stream = op_ctx.run_ctx.get_stream<xpu>();
+  return rsc.get_space_typed<xpu, ndim, char>(shape, stream);
+};
+
 /*!
  * \brief GPU implementation of casting a dns tensor to rsp type.
  */
@@ -226,8 +235,8 @@ inline void CastStorageDnsRspImpl(const OpContext& ctx,
                                 mshadow::Stream<gpu>::GetStream(s));
 
   // Allocate temp storage for marking non-zero rows and for cub's prefix sum
-  mshadow::Tensor<gpu, 1, char> workspace = ctx.requested[0]
-      .get_space_typed<gpu, 1, char>(Shape1(num_rows*sizeof(RType)+temp_storage_bytes), s);
+  auto workspace = AllocateTempDataForCast<gpu, 1>(ctx, Shape1(num_rows*sizeof(RType) +
+                                                               temp_storage_bytes));
   row_flg = reinterpret_cast<RType*>(workspace.dptr_);
   d_temp_storage = workspace.dptr_ + num_rows*sizeof(RType);
 
@@ -633,8 +642,8 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx,
                                 mshadow::Stream<gpu>::GetStream(s));
 
   // Allocate temporary storage
-  mshadow::Tensor<gpu, 1, char> workspace = ctx.requested[0]
-      .get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), s);
+  auto workspace = AllocateTempDataForCast<gpu, 1>(ctx, Shape1(temp_storage_bytes));
+
   d_temp_storage = workspace.dptr_;
 
   // Compute indptr through inclusive prefix sum