Make MPSC count optional, finishes #93
mratsim committed May 16, 2020
1 parent 9f25999 commit 76cc7fd
Showing 6 changed files with 52 additions and 60 deletions.
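
The heart of the change: `ChannelMpscUnboundedBatch` gains a second generic parameter, `keepCount: static bool`, and every operation on the `count` atomic in the hot paths is wrapped in `when keepCount:`. Instantiations that never ask for an item count (the pledges) skip the atomic read-modify-write traffic entirely, while the steal-request channels and the memory pool keep it. Below is a minimal, self-contained sketch of the pattern — invented names, not Weave's code; the sketch additionally elides the field itself, whereas the diff below only gates the operations on `count`:

import std/atomics

type
  CountedQueue[T; keepCount: static bool] = object
    ## Toy single-threaded queue; only the counter is atomic, mirroring
    ## the "optional shared counter" idea of this commit.
    items: seq[T]
    when keepCount:
      count: Atomic[int]

proc push[T; keepCount: static bool](q: var CountedQueue[T, keepCount], x: T) =
  q.items.add x
  when keepCount:
    discard q.count.fetchAdd(1, moRelaxed)

proc pop[T; keepCount: static bool](q: var CountedQueue[T, keepCount]): T =
  result = q.items.pop()
  when keepCount:
    discard q.count.fetchSub(1, moRelaxed)

proc peek[T](q: var CountedQueue[T, true]): int =
  ## Only the counted instantiation can report a cheap size estimate.
  q.count.load(moRelaxed)

var counted: CountedQueue[int, true]
var plain: CountedQueue[int, false]
counted.push 10
plain.push 20
echo counted.peek()                         # 1
echo sizeof(plain), " vs ", sizeof(counted) # the gated field really disappears

The diff below applies the same `when keepCount:` gating to `initialize`, `trySend`, `trySendBatch`, `tryRecv` and `tryRecvBatch`, and threads the new parameter through every instantiation site.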
4 changes: 2 additions & 2 deletions weave/contexts.nim
@@ -61,10 +61,10 @@ template myTodoBoxes*: Persistack[WV_MaxConcurrentStealPerWorker, ChannelSpscSin
template managerJobQueue*: ChannelMpscUnboundedBatch[Job] =
globalCtx.manager.jobsIncoming[]

template myThieves*: ChannelMpscUnboundedBatch[StealRequest] =
template myThieves*: ChannelMpscUnboundedBatch[StealRequest, keepCount = true] =
globalCtx.com.thefts[workerContext.worker.ID]

template getThievesOf*(worker: WorkerID): ChannelMpscUnboundedBatch[StealRequest] =
template getThievesOf*(worker: WorkerID): ChannelMpscUnboundedBatch[StealRequest, keepCount = true] =
globalCtx.com.thefts[worker]

template myMemPool*: TLPoolAllocator =
84 changes: 38 additions & 46 deletions weave/cross_thread_com/channels_mpsc_unbounded_batch.nim
@@ -11,23 +11,6 @@ import
# - https://github.com/nim-lang/Nim/issues/12714
# - https://github.com/nim-lang/Nim/issues/13048

macro derefMPSC*(T: typedesc): typedesc =
# This somehow isn't bound properly when used in a type section
let instantiated = T.getTypeInst
instantiated.expectkind(nnkBracketExpr)
doAssert instantiated[0].eqIdent"typeDesc"

let ptrT = instantiated[1]
if ptrT.kind == nnkPtrTy:
return ptrT[0]

let ptrTImpl = instantiated[1].getImpl
ptrTimpl.expectKind(nnkTypeDef)
ptrTImpl[2].expectKind(nnkPtrTy)
ptrTImpl[2][0].expectKind({nnkObjectTy, nnkSym})

return ptrTImpl[2][0]

# MPSC channel
# ------------------------------------------------

@@ -47,7 +30,7 @@ type
x.next is Atomic[pointer]
# Workaround generic atomics bug: https://github.com/nim-lang/Nim/issues/12695

ChannelMpscUnboundedBatch*[T: Enqueueable] = object
ChannelMpscUnboundedBatch*[T: Enqueueable, keepCount: static bool] = object
## Lockless multi-producer single-consumer channel
##
## Properties:
@@ -70,7 +53,7 @@ type
# Accessed by all
count: Atomic[int]
# Consumer only - front is a dummy node
front{.align: MpscPadding.}: derefMPSC(T)
front{.align: MpscPadding.}: typeof(default(T)[])
# back and front order is chosen so that the data structure can be
# made intrusive to consumer data-structures
# like the memory-pool and the pledges so that
@@ -93,17 +76,19 @@ type
# Implementation
# --------------------------------------------------------------

proc initialize*[T](chan: var ChannelMpscUnboundedBatch[T]) {.inline.}=
proc initialize*[T, keepCount](chan: var ChannelMpscUnboundedBatch[T, keepCount]) {.inline.}=
chan.front.next.store(nil, moRelaxed)
chan.back.store(chan.front.addr, moRelaxed)
chan.count.store(0, moRelaxed)
when keepCount:
chan.count.store(0, moRelaxed)

proc trySend*[T](chan: var ChannelMpscUnboundedBatch[T], src: sink T): bool {.inline.}=
proc trySend*[T, keepCount](chan: var ChannelMpscUnboundedBatch[T, keepCount], src: sink T): bool {.inline.}=
## Send an item to the back of the channel
## As the channel has unbounded capacity, this should never fail

let oldCount {.used.} = chan.count.fetchAdd(1, moRelease)
postCondition: oldCount >= 0
when keepCount:
let oldCount {.used.} = chan.count.fetchAdd(1, moRelease)
ascertain: oldCount >= 0

src.next.store(nil, moRelease)
let oldBack = chan.back.exchange(src, moAcquireRelease)
@@ -113,13 +98,14 @@ proc trySend*[T](chan: var ChannelMpscUnboundedBatch[T], src: sink T): bool {.in

return true

proc trySendBatch*[T](chan: var ChannelMpscUnboundedBatch[T], first, last: sink T, count: SomeInteger): bool {.inline.}=
proc trySendBatch*[T, keepCount](chan: var ChannelMpscUnboundedBatch[T, keepCount], first, last: sink T, count: SomeInteger): bool {.inline.}=
## Send a list of items to the back of the channel
## They should be linked together by their next field
## As the channel has unbounded capacity this should never fail

let oldCount {.used.} = chan.count.fetchAdd(1, moRelease)
postCondition: oldCount >= 0
when keepCount:
let oldCount {.used.} = chan.count.fetchAdd(int(count), moRelease)
ascertain: oldCount >= 0

last.next.store(nil, moRelease)
let oldBack = chan.back.exchange(last, moAcquireRelease)
@@ -129,7 +115,7 @@ proc trySendBatch*[T](chan: var ChannelMpscUnboundedBatch[T], first, last: sink

return true

proc tryRecv*[T](chan: var ChannelMpscUnboundedBatch[T], dst: var T): bool =
proc tryRecv*[T, keepCount](chan: var ChannelMpscUnboundedBatch[T, keepCount], dst: var T): bool =
## Try receiving the next item buffered in the channel
## Returns true if successful (channel was not empty)
## This can fail spuriously on the last element if producer
@@ -153,8 +139,9 @@ proc tryRecv*[T](chan: var ChannelMpscUnboundedBatch[T], dst: var T): bool =
# fence(moAcquire) # Sync "first.next.load(moRelaxed)"
dst = first

let oldCount {.used.} = chan.count.fetchSub(1, moRelaxed)
ascertain: oldCount >= 1 # The producers may overestimate the count
when keepCount:
let oldCount {.used.} = chan.count.fetchSub(1, moRelaxed)
ascertain: oldCount >= 1 # The producers may overestimate the count
return true
# End fast-path

@@ -171,8 +158,9 @@ proc tryRecv*[T](chan: var ChannelMpscUnboundedBatch[T], dst: var T): bool =
prefetch(first)
dst = first

let oldCount {.used.} = chan.count.fetchSub(1, moRelaxed)
ascertain: oldCount >= 1 # The producers may overestimate the count
when keepCount:
let oldCount {.used.} = chan.count.fetchSub(1, moRelaxed)
ascertain: oldCount >= 1 # The producers may overestimate the count
return true

# We lost but now we know that there is an extra node coming very soon
@@ -192,8 +180,9 @@ proc tryRecv*[T](chan: var ChannelMpscUnboundedBatch[T], dst: var T): bool =
# fence(moAcquire) # sync first.next.load(moRelaxed)
dst = first

let oldCount {.used.} = chan.count.fetchSub(1, moRelaxed)
ascertain: oldCount >= 1 # The producers may overestimate the count
when keepCount:
let oldCount {.used.} = chan.count.fetchSub(1, moRelaxed)
ascertain: oldCount >= 1 # The producers may overestimate the count
return true

# # Alternative implementation
@@ -212,7 +201,7 @@ proc tryRecv*[T](chan: var ChannelMpscUnboundedBatch[T], dst: var T): bool =
# # The last item wasn't linked to the list yet, bail out
# return false

proc tryRecvBatch*[T](chan: var ChannelMpscUnboundedBatch[T], bFirst, bLast: var T): int32 =
proc tryRecvBatch*[T, keepCount](chan: var ChannelMpscUnboundedBatch[T, keepCount], bFirst, bLast: var T): int32 =
## Try receiving all items buffered in the channel
## Returns true if at least some items are dequeued.
## There might be competition with producers for the last item
@@ -247,8 +236,9 @@ proc tryRecvBatch*[T](chan: var ChannelMpscUnboundedBatch[T], bFirst, bLast: var
# We lose the competition, bail out
chan.front.next.store(front, moRelease)

let oldCount {.used.} = chan.count.fetchSub(result, moRelaxed)
postCondition: oldCount >= result # TODO: somehow it can be negative
when keepCount:
let oldCount {.used.} = chan.count.fetchSub(result, moRelaxed)
postCondition: oldCount >= result # TODO: somehow it can be negative
return

# front == last
@@ -259,8 +249,9 @@ proc tryRecvBatch*[T](chan: var ChannelMpscUnboundedBatch[T], bFirst, bLast: var
result += 1
bLast = front

let oldCount {.used.} = chan.count.fetchSub(result, moRelaxed)
postCondition: oldCount >= result # TODO: somehow it can be negative
when keepCount:
let oldCount {.used.} = chan.count.fetchSub(result, moRelaxed)
postCondition: oldCount >= result # TODO: somehow it can be negative
return

# We lost but now we know that there is an extra node
@@ -284,8 +275,9 @@ proc tryRecvBatch*[T](chan: var ChannelMpscUnboundedBatch[T], bFirst, bLast: var
# fence(moAcquire) # sync front.next.load(moRelaxed)
bLast = front

let oldCount {.used.} = chan.count.fetchSub(result, moRelaxed)
postCondition: oldCount >= result # TODO: somehow it can be negative
when keepCount:
let oldCount {.used.} = chan.count.fetchSub(result, moRelaxed)
postCondition: oldCount >= result # TODO: somehow it can be negative

func peek*(chan: var ChannelMpscUnboundedBatch): int32 {.inline.} =
## Estimates the number of items pending in the channel
@@ -317,13 +309,13 @@ when isMainModule:
when not compileOption("threads"):
{.error: "This requires --threads:on compilation flag".}

template sendLoop[T](chan: var ChannelMpscUnboundedBatch[T],
template sendLoop[T, keepCount](chan: var ChannelMpscUnboundedBatch[T, keepCount],
data: sink T,
body: untyped): untyped =
while not chan.trySend(data):
body

template recvLoop[T](chan: var ChannelMpscUnboundedBatch[T],
template recvLoop[T, keepCount](chan: var ChannelMpscUnboundedBatch[T, keepCount],
data: var T,
body: untyped): untyped =
while not chan.tryRecv(data):
@@ -358,7 +350,7 @@ when isMainModule:

ThreadArgs = object
ID: WorkerKind
chan: ptr ChannelMpscUnboundedBatch[Val]
chan: ptr ChannelMpscUnboundedBatch[Val, true]

template Worker(id: WorkerKind, body: untyped): untyped {.dirty.} =
if args.ID == id:
@@ -413,7 +405,7 @@ when isMainModule:
echo "Testing if 15 threads can send data to 1 consumer"
echo "------------------------------------------------------------------------"
var threads: array[WorkerKind, Thread[ThreadArgs]]
let chan = createSharedU(ChannelMpscUnboundedBatch[Val]) # CreateU is not zero-init
let chan = createSharedU(ChannelMpscUnboundedBatch[Val, true]) # CreateU is not zero-init
chan[].initialize()

createThread(threads[Receiver], thread_func_receiver, ThreadArgs(ID: Receiver, chan: chan))
@@ -464,7 +456,7 @@ when isMainModule:
echo "Testing if 15 threads can send data to 1 consumer with batch receive"
echo "------------------------------------------------------------------------"
var threads: array[WorkerKind, Thread[ThreadArgs]]
let chan = createSharedU(ChannelMpscUnboundedBatch[Val]) # CreateU is not zero-init
let chan = createSharedU(ChannelMpscUnboundedBatch[Val, true]) # CreateU is not zero-init
chan[].initialize()

# log("Channel address 0x%.08x (dummy 0x%.08x)\n", chan, chan.front.addr)
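
A rough usage sketch of the channel above (illustrative only: `MyNode`/`MyNodePtr` are invented, and the import path is assumed from the file header). Elements must satisfy the `Enqueueable` concept — a `ptr` type whose target carries a `next: Atomic[pointer]` field — and, as the `sendLoop`/`recvLoop` test templates hint, `tryRecv` may fail spuriously on the last element while a producer is mid-enqueue, so a real consumer retries:

import std/atomics
import weave/cross_thread_com/channels_mpsc_unbounded_batch

type
  MyNode = object
    next: Atomic[pointer]   # required by the Enqueueable concept
    payload: int
  MyNodePtr = ptr MyNode

var chan: ChannelMpscUnboundedBatch[MyNodePtr, keepCount = false]
chan.initialize()

# Producer side (any thread): allocate a node and enqueue it.
let node = createShared(MyNode)
node.payload = 42
doAssert chan.trySend(node)        # unbounded capacity, so this should not fail

# Consumer side (single thread): poll until the element is fully linked.
var received: MyNodePtr
while not chan.tryRecv(received):
  discard
echo received.payload
freeShared(received)

Instantiations that want a cheap estimate of pending items keep the counter, as the steal-request and memory-pool channels below do; the pledges drop it.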
6 changes: 3 additions & 3 deletions weave/cross_thread_com/pledges.nim
@@ -171,7 +171,7 @@ type
# The MPSC Channel is intrusive to the PledgeImpl.
# The end fields in the channel should be the consumer
# to avoid cache-line conflicts with producer threads.
chan: ChannelMpscUnboundedBatch[TaskNode]
chan: ChannelMpscUnboundedBatch[TaskNode, keepCount = false]
deferredIn: Atomic[int32]
deferredOut: Atomic[int32]
fulfilled: Atomic[bool]
@@ -535,8 +535,8 @@ debugSizeAsserts:
doAssert sizeof(default(TaskNode)[]) == expectedSize,
"TaskNode size was " & $sizeof(default(TaskNode)[])

doAssert sizeof(ChannelMpscUnboundedBatch[TaskNode]) == 128,
"MPSC channel size was " & $sizeof(ChannelMpscUnboundedBatch[TaskNode])
doAssert sizeof(ChannelMpscUnboundedBatch[TaskNode, false]) == 128,
"MPSC channel size was " & $sizeof(ChannelMpscUnboundedBatch[TaskNode, false])

doAssert sizeof(PledgeImpl) == 192,
"PledgeImpl size was " & $sizeof(PledgeImpl)
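
The field-order comments above and in the channel file ("back and front order is chosen so that the data structure can be made intrusive...", "the end fields in the channel should be the consumer") boil down to a layout trick: producer-written state sits at the front of the channel and the consumer-only dummy node at the end, so a consumer structure like `PledgeImpl` can embed the channel as its first field and keep its own consumer-side fields away from the cache lines producers write. A toy illustration of that idea (invented names and simplified fields, not Weave's types):

import std/atomics

const CacheLineSize = 64   # assumed typical cache-line size

type
  ToyChannel = object
    # Written by producers: isolated on its own cache line.
    back {.align: CacheLineSize.}: Atomic[pointer]
    # Consumer-only dummy node: deliberately the last field.
    front {.align: CacheLineSize.}: pointer

  ToyConsumer = object
    # Embedded first: producer-written `back` sits at the very start of the
    # object, while the consumer-only `front` and the consumer's own fields
    # below are grouped at the end, so producers never write into the cache
    # lines the consumer keeps hot.
    chan: ToyChannel
    fulfilled: Atomic[bool]

echo sizeof(ToyChannel), " ", sizeof(ToyConsumer)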
2 changes: 1 addition & 1 deletion weave/datatypes/context_global.nim
@@ -37,8 +37,8 @@ type
# per channel and a known max number of workers

# Theft channels are bounded to "NumWorkers * WV_MaxConcurrentStealPerWorker"
thefts*: ptr UncheckedArray[ChannelMpscUnboundedBatch[StealRequest]]
thefts*: ptr UncheckedArray[ChannelMpscUnboundedBatch[StealRequest, keepCount = true]]
tasksStolen*: ptr UncheckedArray[Persistack[WV_MaxConcurrentStealPerWorker, ChannelSpscSinglePtr[Task]]]
when static(WV_Backoff):
parking*: ptr UncheckedArray[EventNotifier]

14 changes: 7 additions & 7 deletions weave/memory/memory_pools.nim
@@ -112,7 +112,7 @@ type
# ⚠️ Consumer thread field must be at the end
# to prevent cache-line contention
# and save on space (no padding on the next field)
remoteFree: ChannelMpscUnboundedBatch[ptr MemBlock]
remoteFree: ChannelMpscUnboundedBatch[ptr MemBlock, keepCount = true]
# Freed blocks, kept separately to deterministically trigger slow path
# after an amortized amount of allocation
localFree: ptr MemBlock
@@ -624,8 +624,8 @@ proc takeover*(pool: var TLPoolAllocator, target: sink TLPoolAllocator) =
# the size here will likely be wrong

debugSizeAsserts:
doAssert sizeof(ChannelMpscUnboundedBatch[ptr MemBlock]) == 320,
"MPSC channel size was " & $sizeof(ChannelMpscUnboundedBatch[ptr MemBlock])
doAssert sizeof(ChannelMpscUnboundedBatch[ptr MemBlock, keepCount = true]) == 320,
"MPSC channel size was " & $sizeof(ChannelMpscUnboundedBatch[ptr MemBlock, keepCount = true])

doAssert sizeof(Arena) == WV_MemArenaSize,
"The real arena size was " & $sizeof(Arena) &
@@ -689,13 +689,13 @@ when isMainModule:
when not compileOption("threads"):
{.error: "This requires --threads:on compilation flag".}

template sendLoop[T](chan: var ChannelMpscUnboundedBatch[T],
template sendLoop[T, keepCount](chan: var ChannelMpscUnboundedBatch[T, keepCount],
data: sink T,
body: untyped): untyped =
while not chan.trySend(data):
body

template recvLoop[T](chan: var ChannelMpscUnboundedBatch[T],
template recvLoop[T, keepCount](chan: var ChannelMpscUnboundedBatch[T, keepCount],
data: var T,
body: untyped): untyped =
while not chan.tryRecv(data):
@@ -727,7 +727,7 @@ when isMainModule:

ThreadArgs = object
ID: WorkerKind
chan: ptr ChannelMpscUnboundedBatch[Val]
chan: ptr ChannelMpscUnboundedBatch[Val, true]
pool: ptr TLPoolAllocator

AllocKind = enum
@@ -809,7 +809,7 @@ when isMainModule:
var threads: array[WorkerKind, Thread[ThreadArgs]]
var pools: ptr array[WorkerKind, TLPoolAllocator]

let chan = createSharedU(ChannelMpscUnboundedBatch[Val])
let chan = createSharedU(ChannelMpscUnboundedBatch[Val, true])
chan[].initialize()

pools = cast[typeof pools](createSharedU(TLPoolAllocator, pools[].len))
2 changes: 1 addition & 1 deletion weave/runtime.nim
@@ -45,8 +45,8 @@ proc init*(_: type Weave) =
## Allocation of the global context.
globalCtx.mempools = wv_alloc(TLPoolAllocator, workforce())
globalCtx.threadpool = wv_alloc(Thread[WorkerID], workforce())
globalCtx.com.thefts = wv_alloc(ChannelMpscUnboundedBatch[StealRequest], workforce())
globalCtx.com.thefts = wv_alloc(ChannelMpscUnboundedBatch[StealRequest, true], workforce())
globalCtx.com.tasksStolen = wv_alloc(Persistack[WV_MaxConcurrentStealPerWorker, ChannelSpscSinglePtr[Task]], workforce())
Backoff:
globalCtx.com.parking = wv_alloc(EventNotifier, workforce())
globalCtx.barrier.init(workforce())
