From 397bde0536004153728754da2fcacb6ad369f946 Mon Sep 17 00:00:00 2001
From: Tim Besard <tim.besard@gmail.com>
Date: Tue, 17 Nov 2020 12:55:56 +0100
Subject: [PATCH] Don't use explicit per-stream threads. (#551)

We generally already use the ptsd/ptsz API calls, so the default stream
equals the per-thread one.
---
 lib/cudadrv/stream.jl | 9 +++++++++
 src/array.jl          | 2 +-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/lib/cudadrv/stream.jl b/lib/cudadrv/stream.jl
index 16df97d8f7..43abb7f9a8 100644
--- a/lib/cudadrv/stream.jl
+++ b/lib/cudadrv/stream.jl
@@ -55,6 +55,10 @@ Return the default stream.
     CuStreamLegacy()
 
 Return a special object to use use an implicit stream with legacy synchronization behavior.
+
+You can use this stream to perform operations that should block on all streams (with the
+exception of streams created with `CU_STREAM_NON_BLOCKING`). This matches the old pre-CUDA 7
+global stream behavior.
 """
 @inline CuStreamLegacy() = CuStream(convert(CUstream, 1), CuContext(C_NULL))
 
@@ -62,6 +66,11 @@ Return a special object to use use an implicit stream with legacy synchronizatio
     CuStreamPerThread()
 
 Return a special object to use an implicit stream with per-thread synchronization behavior.
+
+This should generally only be used with compiled libraries, which cannot be switched to the
+per-thread API calls. For all other uses, it be libraries compiled with `nvcc
+--default-stream per-thread` or any CUDA API call using CUDA.jl (which defaults to the
+per-thread variants) you can just use the default `CuDefaultStream` object.
 """
 @inline CuStreamPerThread() = CuStream(convert(CUstream, 2), CuContext(C_NULL))
 
diff --git a/src/array.jl b/src/array.jl
index d63daf3c3f..ac1e9ca16c 100644
--- a/src/array.jl
+++ b/src/array.jl
@@ -306,7 +306,7 @@ end
 
 function Base.unsafe_copyto!(dest::DenseCuArray{T}, doffs, src::DenseCuArray{T}, soffs, n) where T
   GC.@preserve src dest unsafe_copyto!(pointer(dest, doffs), pointer(src, soffs), n;
-                                       async=true, stream=CuStreamPerThread())
+                                       async=true, stream=CuDefaultStream())
   if Base.isbitsunion(T)
     # copy selector bytes
     error("Not implemented")