From 397bde0536004153728754da2fcacb6ad369f946 Mon Sep 17 00:00:00 2001 From: Tim Besard <tim.besard@gmail.com> Date: Tue, 17 Nov 2020 12:55:56 +0100 Subject: [PATCH] Don't use explicit per-thread streams. (#551) We generally already use the ptds/ptsz API calls, so the default stream equals the per-thread one. --- lib/cudadrv/stream.jl | 9 +++++++++ src/array.jl | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/lib/cudadrv/stream.jl b/lib/cudadrv/stream.jl index 16df97d8f7..43abb7f9a8 100644 --- a/lib/cudadrv/stream.jl +++ b/lib/cudadrv/stream.jl @@ -55,6 +55,10 @@ Return the default stream. CuStreamLegacy() Return a special object to use use an implicit stream with legacy synchronization behavior. + +You can use this stream to perform operations that should block on all streams (with the +exception of streams created with `CU_STREAM_NON_BLOCKING`). This matches the old pre-CUDA 7 +global stream behavior. """ @inline CuStreamLegacy() = CuStream(convert(CUstream, 1), CuContext(C_NULL)) @@ -62,6 +66,11 @@ Return a special object to use use an implicit stream with legacy synchronizatio CuStreamPerThread() Return a special object to use an implicit stream with per-thread synchronization behavior. + +This should generally only be used with compiled libraries, which cannot be switched to the +per-thread API calls. For all other uses, be it libraries compiled with `nvcc +--default-stream per-thread` or any CUDA API call using CUDA.jl (which defaults to the +per-thread variants), you can just use the default `CuDefaultStream` object. 
""" @inline CuStreamPerThread() = CuStream(convert(CUstream, 2), CuContext(C_NULL)) diff --git a/src/array.jl b/src/array.jl index d63daf3c3f..ac1e9ca16c 100644 --- a/src/array.jl +++ b/src/array.jl @@ -306,7 +306,7 @@ end function Base.unsafe_copyto!(dest::DenseCuArray{T}, doffs, src::DenseCuArray{T}, soffs, n) where T GC.@preserve src dest unsafe_copyto!(pointer(dest, doffs), pointer(src, soffs), n; - async=true, stream=CuStreamPerThread()) + async=true, stream=CuDefaultStream()) if Base.isbitsunion(T) # copy selector bytes error("Not implemented")