diff --git a/Sia/Entities/Extensions/EntityQueryParallelExtensions.cs b/Sia/Entities/Extensions/EntityQueryParallelExtensions.cs index 124b5d3..2c8afc3 100644 --- a/Sia/Entities/Extensions/EntityQueryParallelExtensions.cs +++ b/Sia/Entities/Extensions/EntityQueryParallelExtensions.cs @@ -6,6 +6,9 @@ namespace Sia; public static class EntityQueryParallelExtensions { + public delegate void RecordFunc(in EntityRef entity, ref TResult result); + public delegate void RecordFunc(in TData data, in EntityRef entity, ref TResult result); + private readonly struct ForEachParallelAction { private readonly ArraySegment _array; @@ -46,40 +49,46 @@ public void Invoke(System.Tuple range) } } - private unsafe struct ForEachParallelData + private readonly unsafe struct RecordData { - public int* Index; - public ArraySegment Array; + public readonly int* Index; + public readonly ArraySegment Array; - public ForEachParallelData(int* index, ArraySegment array) + public RecordData(int* index, ArraySegment array) { Index = index; Array = array; } } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private unsafe static void DoRecord(IEntityQuery query, ArraySegment array) + private readonly unsafe struct RecordData { - int index = 0; - var indexPtr = (int*)Unsafe.AsPointer(ref index); - - query.ForEach(new(indexPtr, array), static (in ForEachParallelData data, in EntityRef entity) => { - ref int index = ref *data.Index; - data.Array[index] = entity; - ++index; - }); + public readonly int* Index; + public readonly ArraySegment Array; + public readonly RecordFunc RecordFunc; + + public RecordData(int* index, ArraySegment array, RecordFunc recordFunc) + { + Index = index; + Array = array; + RecordFunc = recordFunc; + } } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static unsafe SpanOwner Record(this IEntityQuery query) + private readonly unsafe struct RecordData { - var count = query.Count; - if (count == 0) { return default; } - - var spanOwner = SpanOwner.Allocate(count); - DoRecord(query, spanOwner.DangerousGetArray()); - return spanOwner; + public readonly int* Index; + public readonly TData Data; + public readonly ArraySegment Array; + public readonly RecordFunc RecordFunc; + + public RecordData(int* index, in TData data, ArraySegment array, RecordFunc recordFunc) + { + Index = index; + Data = data; + Array = array; + RecordFunc = recordFunc; + } } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -90,7 +99,7 @@ public static unsafe void ForEachParallel(this IEntityQuery query, EntityHandler var spanOwner = SpanOwner.Allocate(count); var array = spanOwner.DangerousGetArray(); - DoRecord(query, array); + RecordEntities(query, array); var action = new ForEachParallelAction(array, handler); Partitioner.Create(0, count) @@ -106,7 +115,7 @@ public static unsafe void ForEachParallel(this IEntityQuery query, in TDa var spanOwner = SpanOwner.Allocate(count); var array = spanOwner.DangerousGetArray(); - DoRecord(query, array); + RecordEntities(query, array); var action = new ForEachParallelAction(array, data, handler); Partitioner.Create(0, count) @@ -123,4 +132,72 @@ public static unsafe void ForEachParallel(this IEntityQuery query, SimpleEntityH public static unsafe void ForEachParallel(this IEntityQuery query, in TData data, SimpleEntityHandler handler) => ForEachParallel(query, (handler, data), static (in (SimpleEntityHandler, TData) data, in EntityRef entity) => data.Item1(data.Item2, entity)); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe SpanOwner Record(this IEntityQuery query) + { + var count = query.Count; + if (count == 0) { return default; } + + var spanOwner = SpanOwner.Allocate(count); + RecordEntities(query, spanOwner.DangerousGetArray()); + return spanOwner; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe SpanOwner Record( + this IEntityQuery query, RecordFunc recordFunc) + { + var count = query.Count; + if (count == 0) { return default; } + + var spanOwner = SpanOwner.Allocate(count); + var array = spanOwner.DangerousGetArray(); + + int index = 0; + var indexPtr = (int*)Unsafe.AsPointer(ref index); + + query.ForEach(new(indexPtr, array, recordFunc), static (in RecordData data, in EntityRef entity) => { + ref int index = ref *data.Index; + data.RecordFunc(entity, ref data.Array.AsSpan()[index]); + ++index; + }); + + return spanOwner; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static unsafe SpanOwner Record( + this IEntityQuery query, in TData data, RecordFunc recordFunc) + { + var count = query.Count; + if (count == 0) { return default; } + + var spanOwner = SpanOwner.Allocate(count); + var array = spanOwner.DangerousGetArray(); + + int index = 0; + var indexPtr = (int*)Unsafe.AsPointer(ref index); + + query.ForEach(new(indexPtr, data, array, recordFunc), static (in RecordData data, in EntityRef entity) => { + ref int index = ref *data.Index; + data.RecordFunc(data.Data, entity, ref data.Array.AsSpan()[index]); + ++index; + }); + + return spanOwner; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private unsafe static void RecordEntities(IEntityQuery query, ArraySegment array) + { + int index = 0; + var indexPtr = (int*)Unsafe.AsPointer(ref index); + + query.ForEach(new(indexPtr, array), static (in RecordData data, in EntityRef entity) => { + ref int index = ref *data.Index; + data.Array[index] = entity; + ++index; + }); + } } \ No newline at end of file