diff --git a/base/channels.jl b/base/channels.jl index 206e8d98927bc..b0c8c44667e2e 100644 --- a/base/channels.jl +++ b/base/channels.jl @@ -33,7 +33,7 @@ mutable struct Channel{T} <: AbstractChannel{T} cond_take::Threads.Condition # waiting for data to become available cond_wait::Threads.Condition # waiting for data to become maybe available cond_put::Threads.Condition # waiting for a writeable slot - state::Symbol + @atomic state::Symbol excp::Union{Exception, Nothing} # exception to be thrown when state !== :open data::Vector{T} @@ -167,6 +167,8 @@ isbuffered(c::Channel) = c.sz_max==0 ? false : true function check_channel_state(c::Channel) if !isopen(c) + # if the monotonic load succeed, now do an acquire fence + (@atomic :acquire c.state) === :open && concurrency_violation() excp = c.excp excp !== nothing && throw(excp) throw(closed_exception()) @@ -183,8 +185,8 @@ Close a channel. An exception (optionally given by `excp`), is thrown by: function close(c::Channel, excp::Exception=closed_exception()) lock(c) try - c.state = :closed c.excp = excp + @atomic :release c.state = :closed notify_error(c.cond_take, excp) notify_error(c.cond_wait, excp) notify_error(c.cond_put, excp) @@ -193,7 +195,7 @@ function close(c::Channel, excp::Exception=closed_exception()) end nothing end -isopen(c::Channel) = (c.state === :open) +isopen(c::Channel) = ((@atomic :monotonic c.state) === :open) """ bind(chnl::Channel, task::Task) @@ -339,6 +341,7 @@ function put_buffered(c::Channel, v) check_channel_state(c) wait(c.cond_put) end + check_channel_state(c) push!(c.data, v) did_buffer = true # notify all, since some of the waiters may be on a "fetch" call. @@ -361,6 +364,7 @@ function put_unbuffered(c::Channel, v) notify(c.cond_wait) wait(c.cond_put) end + check_channel_state(c) # unfair scheduled version of: notify(c.cond_take, v, false, false); yield() popfirst!(c.cond_take.waitq) finally diff --git a/base/task.jl b/base/task.jl index b25197e0aadcc..3f4c1d4ef0e64 100644 --- a/base/task.jl +++ b/base/task.jl @@ -40,6 +40,7 @@ struct CompositeException <: Exception end length(c::CompositeException) = length(c.exceptions) push!(c::CompositeException, ex) = push!(c.exceptions, ex) +pushfirst!(c::CompositeException, ex) = pushfirst!(c.exceptions, ex) isempty(c::CompositeException) = isempty(c.exceptions) iterate(c::CompositeException, state...) = iterate(c.exceptions, state...) eltype(::Type{CompositeException}) = Any @@ -353,6 +354,29 @@ end ## lexically-scoped waiting for multiple items +struct ScheduledAfterSyncException <: Exception + values::Vector{Any} +end + +function showerror(io::IO, ex::ScheduledAfterSyncException) + print(io, "ScheduledAfterSyncException: ") + if isempty(ex.values) + print(io, "(no values)") + return + end + show(io, ex.values[1]) + if length(ex.values) == 1 + print(io, " is") + elseif length(ex.values) == 2 + print(io, " and one more ") + print(io, nameof(typeof(ex.values[2]))) + print(io, " are") + else + print(io, " and ", length(ex.values) - 1, " more objects are") + end + print(io, " registered after the end of a `@sync` block") +end + function sync_end(c::Channel{Any}) local c_ex while isready(c) @@ -377,6 +401,25 @@ function sync_end(c::Channel{Any}) end end close(c) + + # Capture all waitable objects scheduled after the end of `@sync` and + # include them in the exception. This way, the user can check what was + # scheduled by examining at the exception object. + local racy + for r in c + if !@isdefined(racy) + racy = [] + end + push!(racy, r) + end + if @isdefined(racy) + if !@isdefined(c_ex) + c_ex = CompositeException() + end + # Since this is a clear programming error, show this exception first: + pushfirst!(c_ex, ScheduledAfterSyncException(racy)) + end + if @isdefined(c_ex) throw(c_ex) end diff --git a/test/errorshow.jl b/test/errorshow.jl index 32b7c417a5909..9572ccc4af224 100644 --- a/test/errorshow.jl +++ b/test/errorshow.jl @@ -805,6 +805,22 @@ if Sys.isapple() || (Sys.islinux() && Sys.ARCH === :x86_64) end end # Sys.isapple() +@testset "ScheduledAfterSyncException" begin + t = :DummyTask + msg = sprint(showerror, Base.ScheduledAfterSyncException(Any[t])) + @test occursin(":DummyTask is registered after the end of a `@sync` block", msg) + msg = sprint(showerror, Base.ScheduledAfterSyncException(Any[t, t])) + @test occursin( + ":DummyTask and one more Symbol are registered after the end of a `@sync` block", + msg, + ) + msg = sprint(showerror, Base.ScheduledAfterSyncException(Any[t, t, t])) + @test occursin( + ":DummyTask and 2 more objects are registered after the end of a `@sync` block", + msg, + ) +end + @testset "error message hints relative modules #40959" begin m = Module() expr = :(module Foo diff --git a/test/threads_exec.jl b/test/threads_exec.jl index b4c28d20b89cd..cba79d807e6f1 100644 --- a/test/threads_exec.jl +++ b/test/threads_exec.jl @@ -913,6 +913,67 @@ end end end +# @spawn racying with sync_end + +hidden_spawn(f) = Threads.@spawn f() + +function sync_end_race() + y = Ref(:notset) + local t + @sync begin + for _ in 1:6 # tweaked to maximize `nerror` below + Threads.@spawn nothing + end + t = hidden_spawn() do + Threads.@spawn y[] = :completed + end + end + try + wait(t) + catch + return :notscheduled + end + return y[] +end + +function check_sync_end_race() + @sync begin + done = Threads.Atomic{Bool}(false) + try + # `Threads.@spawn` must fail to be scheduled or complete its execution: + ncompleted = 0 + nnotscheduled = 0 + nerror = 0 + for i in 1:1000 + y = try + yield() + sync_end_race() + catch err + if err isa CompositeException + if err.exceptions[1] isa Base.ScheduledAfterSyncException + nerror += 1 + continue + end + end + rethrow() + end + y in (:completed, :notscheduled) || return (; i, y) + ncompleted += y === :completed + nnotscheduled += y === :notscheduled + end + # Useful for tuning the test: + @debug "`check_sync_end_race` done" nthreads() ncompleted nnotscheduled nerror + finally + done[] = true + end + end + return nothing +end + +@testset "Racy `@spawn`" begin + @test check_sync_end_race() === nothing +end + # issue #41546, thread-safe package loading @testset "package loading" begin ch = Channel{Bool}(nthreads())