Skip to content

Commit

Permalink
Stash removing useless count for pledges #93. Showstopper early resol…
Browse files Browse the repository at this point in the history
…ution bug nim-lang/Nim#8677 (static sandwich? nim-lang/Nim#11225)
  • Loading branch information
mratsim committed Jan 12, 2020
1 parent a6e06bf commit 2ecc4b1
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 38 deletions.
61 changes: 36 additions & 25 deletions weave/channels/channels_mpsc_unbounded_batch.nim
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ import
# - https://github.com/nim-lang/Nim/issues/12714
# - https://github.com/nim-lang/Nim/issues/13048

macro derefMPSC*(T: typedesc): typedesc =
# This somehows isn't bound properly when used in a typesection
macro derefMPSC*(keepCount: static bool, T: typedesc): typedesc =
# This somehows isn't bound properly when used in a typesection (it needs to be exported in this module)
# and it's even worse if that type section involves a static (it needs to be reexported without hops to all modules!)
let instantiated = T.getTypeInst
instantiated.expectkind(nnkBracketExpr)
doAssert instantiated[0].eqIdent"typeDesc"
Expand Down Expand Up @@ -44,7 +45,7 @@ type
x.next is Atomic[pointer]
# Workaround generic atomics bug: https://github.com/nim-lang/Nim/issues/12695

ChannelMpscUnboundedBatch*[T: Enqueueable] = object
ChannelMpscUnboundedBatch*[keepCount: static bool, T: Enqueueable] = object
## Lockless multi-producer single-consumer channel
##
## Properties:
Expand All @@ -64,7 +65,7 @@ type

# Producers and consumer slow-path
back{.align: MpscPadding.}: Atomic[pointer] # Workaround generic atomics bug: https://github.com/nim-lang/Nim/issues/12695
# Accessed by all
# Accessed by all - don't remove it even when not keeping count. This pads back/front by 2x cache lines.
count{.align: MpscPadding.}: Atomic[int]
# Consumer only - front is a dummy node
front{.align: MpscPadding.}: derefMPSC(T)
Expand All @@ -87,12 +88,13 @@ type
# Implementation
# --------------------------------------------------------------

proc initialize*[T](chan: var ChannelMpscUnboundedBatch[T]) {.inline.}=
proc initialize*[keepCount: static bool, T](chan: var ChannelMpscUnboundedBatch[keepCount, T]) {.inline.}=
chan.front.next.store(nil, moRelaxed)
chan.back.store(chan.front.addr, moRelaxed)
chan.count.store(0, moRelaxed)
when keepCount:
chan.count.store(0, moRelaxed)

proc trySend*[T](chan: var ChannelMpscUnboundedBatch[T], src: sink T): bool {.inline.}=
proc trySend*[keepCount: static bool, T](chan: var ChannelMpscUnboundedBatch[keepCount, T], src: sink T): bool {.inline.}=
## Send an item to the back of the channel
## As the channel has unbounded capacity, this should never fail

Expand All @@ -106,12 +108,13 @@ proc trySend*[T](chan: var ChannelMpscUnboundedBatch[T], src: sink T): bool {.in

return true

proc trySendBatch*[T](chan: var ChannelMpscUnboundedBatch[T], first, last: sink T, count: SomeInteger): bool {.inline.}=
proc trySendBatch*[keepCount: static bool, T](chan: var ChannelMpscUnboundedBatch[keepCount, T], first, last: sink T, count: SomeInteger): bool {.inline.}=
## Send a list of items to the back of the channel
## They should be linked together by their next field
## As the channel has unbounded capacity this should never fail

discard chan.count.fetchAdd(int(count), moRelaxed)
when keepCount:
discard chan.count.fetchAdd(int(count), moRelaxed)
last.next.store(nil, moRelaxed)
fence(moRelease)
let oldBack = chan.back.exchange(last, moRelaxed)
Expand All @@ -121,7 +124,7 @@ proc trySendBatch*[T](chan: var ChannelMpscUnboundedBatch[T], first, last: sink

return true

proc tryRecv*[T](chan: var ChannelMpscUnboundedBatch[T], dst: var T): bool =
proc tryRecv*[keepCount: static bool, T](chan: var ChannelMpscUnboundedBatch[keepCount, T], dst: var T): bool =
## Try receiving the next item buffered in the channel
## Returns true if successful (channel was not empty)
## This can fail spuriously on the last element if producer
Expand All @@ -137,7 +140,8 @@ proc tryRecv*[T](chan: var ChannelMpscUnboundedBatch[T], dst: var T): bool =
if not next.isNil:
# Not competing with producers
prefetch(first)
discard chan.count.fetchSub(1, moRelaxed)
when keepCount:
discard chan.count.fetchSub(1, moRelaxed)
chan.front.next.store(next, moRelaxed)
fence(moAcquire)
dst = first
Expand All @@ -153,7 +157,8 @@ proc tryRecv*[T](chan: var ChannelMpscUnboundedBatch[T], dst: var T): bool =
chan.front.next.store(nil, moRelaxed)
if compareExchange(chan.back, last, chan.front.addr, moAcquireRelease):
# We won and replaced the last node with the channel front
discard chan.count.fetchSub(1, moRelaxed)
when keepCount:
discard chan.count.fetchSub(1, moRelaxed)
dst = first
return true

Expand All @@ -170,13 +175,14 @@ proc tryRecv*[T](chan: var ChannelMpscUnboundedBatch[T], dst: var T): bool =
next = first.next.load(moRelaxed)

prefetch(first)
discard chan.count.fetchSub(1, moRelaxed)
when keepCount:
discard chan.count.fetchSub(1, moRelaxed)
chan.front.next.store(next, moRelaxed)
fence(moAcquire)
dst = first
return true

proc tryRecvBatch*[T](chan: var ChannelMpscUnboundedBatch[T], bFirst, bLast: var T): int32 =
proc tryRecvBatch*[keepCount: static bool, T](chan: var ChannelMpscUnboundedBatch[keepCount, T], bFirst, bLast: var T): int32 =
## Try receiving all items buffered in the channel
## Returns true if at least some items are dequeued.
## There might be competition with producers for the last item
Expand Down Expand Up @@ -210,8 +216,9 @@ proc tryRecvBatch*[T](chan: var ChannelMpscUnboundedBatch[T], bFirst, bLast: var
if front != last:
# We lose the competition, bail out
chan.front.next.store(front, moRelaxed)
discard chan.count.fetchSub(result, moRelaxed)
postCondition: chan.count.load(moRelaxed) >= 0 # TODO: somehow it can be negative
when keepCount:
discard chan.count.fetchSub(result, moRelaxed)
postCondition: chan.count.load(moRelaxed) >= 0 # TODO: somehow it can be negative
return

# front == last
Expand All @@ -220,9 +227,10 @@ proc tryRecvBatch*[T](chan: var ChannelMpscUnboundedBatch[T], bFirst, bLast: var
# We won and replaced the last node with the channel front
prefetch(front)
result += 1
discard chan.count.fetchSub(result, moRelaxed)
bLast = front
postCondition: chan.count.load(moRelaxed) >= 0
when keepCount:
discard chan.count.fetchSub(result, moRelaxed)
postCondition: chan.count.load(moRelaxed) >= 0
return

# We lost but now we know that there is an extra node
Expand All @@ -242,11 +250,13 @@ proc tryRecvBatch*[T](chan: var ChannelMpscUnboundedBatch[T], bFirst, bLast: var

prefetch(front)
result += 1
discard chan.count.fetchSub(result, moRelaxed)
chan.front.next.store(next, moRelaxed)
fence(moAcquire)
bLast = front
postCondition: chan.count.load(moRelaxed) >= 0

when keepCount:
discard chan.count.fetchSub(result, moRelaxed)
postCondition: chan.count.load(moRelaxed) >= 0

func peek*(chan: var ChannelMpscUnboundedBatch): int32 {.inline.} =
## Estimates the number of items pending in the channel
Expand All @@ -257,6 +267,7 @@ func peek*(chan: var ChannelMpscUnboundedBatch): int32 {.inline.} =
## the consumer removes them concurrently.
##
## This is a non-locking operation.
static: doAssert chan.keepCount
result = int32 chan.count.load(moAcquire)

# For the consumer it's always positive or zero
Expand All @@ -278,13 +289,13 @@ when isMainModule:
when not compileOption("threads"):
{.error: "This requires --threads:on compilation flag".}

template sendLoop[T](chan: var ChannelMpscUnboundedBatch[T],
template sendLoop[keepCount: static bool, T](chan: var ChannelMpscUnboundedBatch[keepCount, T],
data: sink T,
body: untyped): untyped =
while not chan.trySend(data):
body

template recvLoop[T](chan: var ChannelMpscUnboundedBatch[T],
template recvLoop[keepCount: static bool, T](chan: var ChannelMpscUnboundedBatch[keepCount, T],
data: var T,
body: untyped): untyped =
while not chan.tryRecv(data):
Expand Down Expand Up @@ -319,7 +330,7 @@ when isMainModule:

ThreadArgs = object
ID: WorkerKind
chan: ptr ChannelMpscUnboundedBatch[Val]
chan: ptr ChannelMpscUnboundedBatch[true, Val]

template Worker(id: WorkerKind, body: untyped): untyped {.dirty.} =
if args.ID == id:
Expand Down Expand Up @@ -374,7 +385,7 @@ when isMainModule:
echo "Testing if 15 threads can send data to 1 consumer"
echo "------------------------------------------------------------------------"
var threads: array[WorkerKind, Thread[ThreadArgs]]
let chan = createSharedU(ChannelMpscUnboundedBatch[Val]) # CreateU is not zero-init
let chan = createSharedU(ChannelMpscUnboundedBatch[true, Val]) # CreateU is not zero-init
chan[].initialize()

createThread(threads[Receiver], thread_func_receiver, ThreadArgs(ID: Receiver, chan: chan))
Expand Down Expand Up @@ -425,7 +436,7 @@ when isMainModule:
echo "Testing if 15 threads can send data to 1 consumer with batch receive"
echo "------------------------------------------------------------------------"
var threads: array[WorkerKind, Thread[ThreadArgs]]
let chan = createSharedU(ChannelMpscUnboundedBatch[Val]) # CreateU is not zero-init
let chan = createSharedU(ChannelMpscUnboundedBatch[true, Val]) # CreateU is not zero-init
chan[].initialize()

# log("Channel address 0x%.08x (dummy 0x%.08x)\n", chan, chan.front.addr)
Expand Down
6 changes: 3 additions & 3 deletions weave/channels/pledges.nim
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ type
# The MPSC Channel is intrusive to the PledgeImpl.
# The end fields in the channel should be the consumer
# to avoid cache-line conflicts with producer threads.
chan{.align: WV_CacheLinePadding div 2.}: ChannelMpscUnboundedBatch[TaskNode]
chan{.align: WV_CacheLinePadding div 2.}: ChannelMpscUnboundedBatch[false, TaskNode]
deferredIn: Atomic[int32]
deferredOut: Atomic[int32]
fulfilled: Atomic[bool]
Expand Down Expand Up @@ -513,8 +513,8 @@ macro delayedUntilMulti*(task: Task, pool: var TLPoolAllocator, pledges: varargs
# TODO: Once upstream fixes https://github.com/nim-lang/Nim/issues/13122
# the size here will be wrong

assert sizeof(ChannelMpscUnboundedBatch[TaskNode]) == 56, # Upstream {.align.} bug
"MPSC channel size was " & $sizeof(ChannelMpscUnboundedBatch[TaskNode])
assert sizeof(ChannelMpscUnboundedBatch[false, TaskNode]) == 56, # Upstream {.align.} bug
"MPSC channel size was " & $sizeof(ChannelMpscUnboundedBatch[false, TaskNode])

assert sizeof(PledgeImpl) == 128,
"PledgeImpl size was " & $sizeof(PledgeImpl)
Expand Down
6 changes: 4 additions & 2 deletions weave/contexts.nim
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import
./config,
./instrumentation/[profilers, loggers, contracts]

export derefMPSC # Need to be reexported due to a static early resolution bug :/.

when defined(WV_metrics):
import system/ansi_c, ./primitives/barriers

Expand Down Expand Up @@ -49,10 +51,10 @@ template isRootTask*(task: Task): bool =
template myTodoBoxes*: Persistack[WV_MaxConcurrentStealPerWorker, ChannelSpscSinglePtr[Task]] =
globalCtx.com.tasks[localCtx.worker.ID]

template myThieves*: ChannelMpscUnboundedBatch[StealRequest] =
template myThieves*: ChannelMpscUnboundedBatch[true, StealRequest] =
globalCtx.com.thefts[localCtx.worker.ID]

template getThievesOf*(worker: WorkerID): ChannelMpscUnboundedBatch[StealRequest] =
template getThievesOf*(worker: WorkerID): ChannelMpscUnboundedBatch[true, StealRequest] =
globalCtx.com.thefts[worker]

template myMemPool*: TLPoolAllocator =
Expand Down
2 changes: 1 addition & 1 deletion weave/datatypes/context_global.nim
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type
# per channel and a known max number of workers

# Theft channels are bounded to "NumWorkers * WV_MaxConcurrentStealPerWorker"
thefts*: ptr UncheckedArray[ChannelMpscUnboundedBatch[StealRequest]]
thefts*: ptr UncheckedArray[ChannelMpscUnboundedBatch[true, StealRequest]]
tasks*: ptr UncheckedArray[Persistack[WV_MaxConcurrentStealPerWorker, ChannelSpscSinglePtr[Task]]]
when static(WV_Backoff):
parking*: ptr UncheckedArray[EventNotifier]
Expand Down
16 changes: 9 additions & 7 deletions weave/memory/memory_pools.nim
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ import
./allocs, ./thread_id,
std/atomics

export derefMPSC # Need to be reexported due to a static early resolution bug :/.

# Constants (move in config.nim)
# ----------------------------------------------------------------------------------

Expand Down Expand Up @@ -112,7 +114,7 @@ type
# ⚠️ Consumer thread field must be at the end
# to prevent cache-line contention
# and save on space (no padding on the next field)
remoteFree {.align: WV_CacheLinePadding.}: ChannelMpscUnboundedBatch[ptr MemBlock]
remoteFree {.align: WV_CacheLinePadding.}: ChannelMpscUnboundedBatch[true, ptr MemBlock]
# Freed blocks, kept separately to deterministically trigger slow path
# after an amortized amount of allocation
localFree: ptr MemBlock
Expand Down Expand Up @@ -623,8 +625,8 @@ proc takeover*(pool: var TLPoolAllocator, target: sink TLPoolAllocator) =
# TODO: Once upstream fixes https://github.com/nim-lang/Nim/issues/13122
# the size here will likely be wrong

assert sizeof(ChannelMpscUnboundedBatch[ptr MemBlock]) == 272,
"MPSC channel size was " & $sizeof(ChannelMpscUnboundedBatch[ptr MemBlock])
assert sizeof(ChannelMpscUnboundedBatch[true, ptr MemBlock]) == 272,
"MPSC channel size was " & $sizeof(ChannelMpscUnboundedBatch[true, ptr MemBlock])

assert sizeof(Arena) == WV_MemArenaSize,
"The real arena size was " & $sizeof(Arena) &
Expand Down Expand Up @@ -688,13 +690,13 @@ when isMainModule:
when not compileOption("threads"):
{.error: "This requires --threads:on compilation flag".}

template sendLoop[T](chan: var ChannelMpscUnboundedBatch[T],
template sendLoop[keepCount: static bool, T](chan: var ChannelMpscUnboundedBatch[keepCount, T],
data: sink T,
body: untyped): untyped =
while not chan.trySend(data):
body

template recvLoop[T](chan: var ChannelMpscUnboundedBatch[T],
template recvLoop[keepCount: static bool, T](chan: var ChannelMpscUnboundedBatch[keepCount, T],
data: var T,
body: untyped): untyped =
while not chan.tryRecv(data):
Expand Down Expand Up @@ -726,7 +728,7 @@ when isMainModule:

ThreadArgs = object
ID: WorkerKind
chan: ptr ChannelMpscUnboundedBatch[Val]
chan: ptr ChannelMpscUnboundedBatch[true, Val]
pool: ptr TLPoolAllocator

AllocKind = enum
Expand Down Expand Up @@ -808,7 +810,7 @@ when isMainModule:
var threads: array[WorkerKind, Thread[ThreadArgs]]
var pools: ptr array[WorkerKind, TLPoolAllocator]

let chan = createSharedU(ChannelMpscUnboundedBatch[Val])
let chan = createSharedU(ChannelMpscUnboundedBatch[true, Val])
chan[].initialize()

pools = cast[typeof pools](createSharedU(TLPoolAllocator, pools[].len))
Expand Down

0 comments on commit 2ecc4b1

Please sign in to comment.