Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/CommunityToolkit.HighPerformance/Memory/Memory2D{T}.cs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ public Memory2D(MemoryManager<T> memoryManager, int height, int width)
/// <exception cref="ArgumentException">
/// Thrown when the requested area is outside of bounds for <paramref name="memoryManager"/>.
/// </exception>
public Memory2D(MemoryManager<T> memoryManager, int offset, int height, int width, int pitch)
public unsafe Memory2D(MemoryManager<T> memoryManager, int offset, int height, int width, int pitch)
{
int length = memoryManager.GetSpan().Length;

Expand Down Expand Up @@ -378,7 +378,7 @@ public Memory2D(MemoryManager<T> memoryManager, int offset, int height, int widt
}

this.instance = memoryManager;
this.offset = (nint)(uint)offset;
this.offset = (nint)(uint)offset * (nint)(uint)sizeof(T);
this.height = height;
this.width = width;
this.pitch = pitch;
Expand Down Expand Up @@ -413,7 +413,7 @@ internal Memory2D(Memory<T> memory, int height, int width)
/// <exception cref="ArgumentException">
/// Thrown when the requested area is outside of bounds for <paramref name="memory"/>.
/// </exception>
internal Memory2D(Memory<T> memory, int offset, int height, int width, int pitch)
internal unsafe Memory2D(Memory<T> memory, int offset, int height, int width, int pitch)
{
if ((uint)offset > (uint)memory.Length)
{
Expand Down Expand Up @@ -477,7 +477,7 @@ internal Memory2D(Memory<T> memory, int offset, int height, int width, int pitch
else if (MemoryMarshal.TryGetMemoryManager<T, MemoryManager<T>>(memory, out MemoryManager<T>? memoryManager, out int memoryManagerStart, out _))
{
this.instance = memoryManager;
this.offset = (nint)(uint)(memoryManagerStart + offset);
this.offset = (nint)(uint)(memoryManagerStart + offset) * (nint)(uint)sizeof(T);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ public ReadOnlyMemory2D(MemoryManager<T> memoryManager, int height, int width)
/// <exception cref="ArgumentException">
/// Thrown when the requested area is outside of bounds for <paramref name="memoryManager"/>.
/// </exception>
public ReadOnlyMemory2D(MemoryManager<T> memoryManager, int offset, int height, int width, int pitch)
public unsafe ReadOnlyMemory2D(MemoryManager<T> memoryManager, int offset, int height, int width, int pitch)
{
int length = memoryManager.GetSpan().Length;

Expand Down Expand Up @@ -398,7 +398,7 @@ public ReadOnlyMemory2D(MemoryManager<T> memoryManager, int offset, int height,
}

this.instance = memoryManager;
this.offset = (nint)(uint)offset;
this.offset = (nint)(uint)offset * (nint)(uint)sizeof(T);
this.height = height;
this.width = width;
this.pitch = pitch;
Expand Down Expand Up @@ -433,7 +433,7 @@ internal ReadOnlyMemory2D(ReadOnlyMemory<T> memory, int height, int width)
/// <exception cref="ArgumentException">
/// Thrown when the requested area is outside of bounds for <paramref name="memory"/>.
/// </exception>
internal ReadOnlyMemory2D(ReadOnlyMemory<T> memory, int offset, int height, int width, int pitch)
internal unsafe ReadOnlyMemory2D(ReadOnlyMemory<T> memory, int offset, int height, int width, int pitch)
{
if ((uint)offset > (uint)memory.Length)
{
Expand Down Expand Up @@ -489,7 +489,7 @@ internal ReadOnlyMemory2D(ReadOnlyMemory<T> memory, int offset, int height, int
else if (MemoryMarshal.TryGetMemoryManager(memory, out MemoryManager<T>? memoryManager, out int memoryManagerStart, out _))
{
this.instance = memoryManager;
this.offset = (nint)(uint)(memoryManagerStart + offset);
this.offset = (nint)(uint)(memoryManagerStart + offset) * (nint)(uint)sizeof(T);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
// See the LICENSE file in the project root for more information.

using System;
#if NET6_0_OR_GREATER
using System.Buffers;
#endif
using System.Runtime.CompilerServices;
using CommunityToolkit.HighPerformance.Enumerables;
using Microsoft.VisualStudio.TestTools.UnitTesting;
Expand Down Expand Up @@ -1014,4 +1017,96 @@ public void Test_ReadOnlySpan2DT_ReadOnlyRefEnumerable_Cast()

CollectionAssert.AreEqual(result, row);
}

#if NET6_0_OR_GREATER
[TestMethod]
public void Test_ReadOnlySpan2DT_FromMemoryManager_Indexing()
{
const int w = 10;
const int h = 10;
const int l = w * h;

byte[] b = new byte[l];
short[] s = new short[l];

for (int i = 0; i < l; ++i)
{
b[i] = (byte)i;
s[i] = (short)i;
}

Memory2DTester<byte> byteTester = new(w, h, b);
Span2D<byte> byteSpan2DFromArray = byteTester.GetMemory2DFromArray().Span;

Assert.AreEqual(11, byteSpan2DFromArray[0, 0]);

Span2D<byte> byteSpan2DFromMemoryManager = byteTester.GetMemory2DFromMemoryManager().Span;

Assert.AreEqual(11, byteSpan2DFromMemoryManager[0, 0]);

Memory2DTester<short> shortTester = new(w, h, s);
Span2D<short> shortSpan2DFromArray = shortTester.GetMemory2DFromArray().Span;
Span2D<short> shortSpan2DFromMemoryManager = shortTester.GetMemory2DFromMemoryManager().Span;

Assert.AreEqual(11, shortSpan2DFromArray[0, 0]);
Assert.AreEqual(11, shortSpan2DFromMemoryManager[0, 0]);
}
#endif
}

#if NET6_0_OR_GREATER
public sealed class Memory2DTester<T> : MemoryManager<T>
where T : unmanaged
{
private readonly T[] data;

public Memory2DTester(int w, int h, T[] data)
{
if (w < 2 || h < 2)
{
throw new ArgumentException("The 'w' and 'h' arguments must be at least 2.");
}

this.data = data;

Width = w;
Height = h;
}

public int Width { get; }

public int Height { get; }

public Memory2D<T> GetMemory2DFromMemoryManager()
{
return new(this, Width + 1, Height - 1, Width - 1, 1);
}

public Memory2D<T> GetMemory2DFromArray()
{
return new(this.data, Width + 1, Height - 1, Width - 1, 1);
}

/// <inheritdoc/>
public override Span<T> GetSpan()
{
return new(this.data);
}

/// <inheritdoc/>
public override MemoryHandle Pin(int elementIndex = 0)
{
return default;
}

/// <inheritdoc/>
public override void Unpin()
{
}

/// <inheritdoc/>
protected override void Dispose(bool disposing)
{
}
}
#endif