diff --git a/src/resource.cc b/src/resource.cc index 80a5c0e444e1..0317ff32bbf3 100644 --- a/src/resource.cc +++ b/src/resource.cc @@ -432,6 +432,9 @@ void Resource::get_cudnn_dropout_desc( // not initialized yet. size_t dropout_state_size; CUDNN_CALL(cudnnDropoutGetStatesSize(stream->dnn_handle_, &dropout_state_size)); + // reserve GPU space + Storage::Get()->DirectFree( + Storage::Get()->Alloc(dropout_state_size, state_space->ctx)); CUDNN_CALL(cudnnSetDropoutDescriptor(*dropout_desc, stream->dnn_handle_, dropout, state_space->GetSpace(dropout_state_size),