diff --git a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs index 850ad307f7c4fb..41a1c32987c353 100644 --- a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs @@ -55,21 +55,28 @@ namespace System.Buffers } namespace System.Numerics.Tensors { - [System.Diagnostics.CodeAnalysis.Experimental("SYSLIB5001", UrlFormat = "https://aka.ms/dotnet-warnings/{0}")] - public partial interface IReadOnlyTensor : System.Collections.Generic.IEnumerable, System.Collections.IEnumerable where TSelf : System.Numerics.Tensors.IReadOnlyTensor + [System.Diagnostics.CodeAnalysis.ExperimentalAttribute("SYSLIB5001", UrlFormat="https://aka.ms/dotnet-warnings/{0}")] + public partial interface IReadOnlyTensor { - static abstract TSelf? Empty { get; } nint FlattenedLength { get; } bool IsEmpty { get; } bool IsPinned { get; } - T this[params scoped System.ReadOnlySpan indexes] { get; } - TSelf this[params scoped System.ReadOnlySpan ranges] { get; } - T this[params scoped System.ReadOnlySpan indexes] { get; } + object this[params scoped System.ReadOnlySpan indexes] { get; } + object this[params scoped System.ReadOnlySpan indexes] { get; } [System.Diagnostics.CodeAnalysis.UnscopedRefAttribute] System.ReadOnlySpan Lengths { get; } int Rank { get; } [System.Diagnostics.CodeAnalysis.UnscopedRefAttribute] System.ReadOnlySpan Strides { get; } + System.Buffers.MemoryHandle GetPinnedHandle(); + } + [System.Diagnostics.CodeAnalysis.ExperimentalAttribute("SYSLIB5001", UrlFormat="https://aka.ms/dotnet-warnings/{0}")] + public partial interface IReadOnlyTensor : System.Collections.Generic.IEnumerable, System.Collections.IEnumerable, System.Numerics.Tensors.IReadOnlyTensor where TSelf : System.Numerics.Tensors.IReadOnlyTensor + { + static abstract TSelf? Empty { get; } + new T this[params scoped System.ReadOnlySpan indexes] { get; } + TSelf this[params scoped System.ReadOnlySpan ranges] { get; } + new T this[params scoped System.ReadOnlySpan indexes] { get; } System.Numerics.Tensors.ReadOnlyTensorSpan AsReadOnlyTensorSpan(); System.Numerics.Tensors.ReadOnlyTensorSpan AsReadOnlyTensorSpan(params scoped System.ReadOnlySpan startIndex); System.Numerics.Tensors.ReadOnlyTensorSpan AsReadOnlyTensorSpan(params scoped System.ReadOnlySpan range); @@ -83,10 +90,18 @@ public partial interface IReadOnlyTensor : System.Collections.Generic. bool TryCopyTo(scoped System.Numerics.Tensors.TensorSpan destination); bool TryFlattenTo(scoped System.Span destination); } - [System.Diagnostics.CodeAnalysis.Experimental("SYSLIB5001", UrlFormat = "https://aka.ms/dotnet-warnings/{0}")] - public partial interface ITensor : System.Collections.Generic.IEnumerable, System.Collections.IEnumerable, System.Numerics.Tensors.IReadOnlyTensor where TSelf : System.Numerics.Tensors.ITensor + [System.Diagnostics.CodeAnalysis.ExperimentalAttribute("SYSLIB5001", UrlFormat="https://aka.ms/dotnet-warnings/{0}")] + public partial interface ITensor : System.Numerics.Tensors.IReadOnlyTensor { bool IsReadOnly { get; } + new object this[params scoped System.ReadOnlySpan indexes] { get; set; } + new object this[params scoped System.ReadOnlySpan indexes] { get; set; } + void Clear(); + void Fill(object value); + } + [System.Diagnostics.CodeAnalysis.ExperimentalAttribute("SYSLIB5001", UrlFormat="https://aka.ms/dotnet-warnings/{0}")] + public partial interface ITensor : System.Collections.Generic.IEnumerable, System.Collections.IEnumerable, System.Numerics.Tensors.IReadOnlyTensor, System.Numerics.Tensors.IReadOnlyTensor, System.Numerics.Tensors.ITensor where TSelf : System.Numerics.Tensors.ITensor + { new T this[params scoped System.ReadOnlySpan indexes] { get; set; } new TSelf this[params scoped System.ReadOnlySpan ranges] { get; set; } new T this[params scoped System.ReadOnlySpan indexes] { get; set; } @@ -94,7 +109,6 @@ public partial interface ITensor : System.Collections.Generic.IEnumera System.Numerics.Tensors.TensorSpan AsTensorSpan(params scoped System.ReadOnlySpan startIndex); System.Numerics.Tensors.TensorSpan AsTensorSpan(params scoped System.ReadOnlySpan range); System.Numerics.Tensors.TensorSpan AsTensorSpan(params scoped System.ReadOnlySpan start); - void Clear(); static abstract TSelf Create(scoped System.ReadOnlySpan lengths, bool pinned = false); static abstract TSelf Create(scoped System.ReadOnlySpan lengths, scoped System.ReadOnlySpan strides, bool pinned = false); static abstract TSelf CreateUninitialized(scoped System.ReadOnlySpan lengths, bool pinned = false); @@ -816,8 +830,8 @@ public ref partial struct Enumerator public bool MoveNext() { throw null; } } } - [System.Diagnostics.CodeAnalysis.Experimental("SYSLIB5001", UrlFormat = "https://aka.ms/dotnet-warnings/{0}")] - public sealed partial class Tensor : System.Collections.Generic.IEnumerable, System.Collections.IEnumerable, System.Numerics.Tensors.IReadOnlyTensor, T>, System.Numerics.Tensors.ITensor, T> + [System.Diagnostics.CodeAnalysis.ExperimentalAttribute("SYSLIB5001", UrlFormat="https://aka.ms/dotnet-warnings/{0}")] + public sealed partial class Tensor : System.Collections.Generic.IEnumerable, System.Collections.IEnumerable, System.Numerics.Tensors.IReadOnlyTensor, System.Numerics.Tensors.IReadOnlyTensor, T>, System.Numerics.Tensors.ITensor, System.Numerics.Tensors.ITensor, T> { internal Tensor() { } public static System.Numerics.Tensors.Tensor Empty { get { throw null; } } @@ -831,12 +845,16 @@ internal Tensor() { } public System.ReadOnlySpan Lengths { get { throw null; } } public int Rank { get { throw null; } } public System.ReadOnlySpan Strides { get { throw null; } } + object System.Numerics.Tensors.IReadOnlyTensor.this[params scoped System.ReadOnlySpan indexes] { get { throw null; } } + object System.Numerics.Tensors.IReadOnlyTensor.this[params scoped System.ReadOnlySpan indexes] { get { throw null; } } + System.ReadOnlySpan System.Numerics.Tensors.IReadOnlyTensor.Lengths { get { throw null; } } + System.ReadOnlySpan System.Numerics.Tensors.IReadOnlyTensor.Strides { get { throw null; } } T System.Numerics.Tensors.IReadOnlyTensor, T>.this[params scoped System.ReadOnlySpan indexes] { get { throw null; } } System.Numerics.Tensors.Tensor System.Numerics.Tensors.IReadOnlyTensor, T>.this[params scoped System.ReadOnlySpan ranges] { get { throw null; } } T System.Numerics.Tensors.IReadOnlyTensor, T>.this[params scoped System.ReadOnlySpan indexes] { get { throw null; } } - System.ReadOnlySpan System.Numerics.Tensors.IReadOnlyTensor, T>.Lengths { get { throw null; } } - System.ReadOnlySpan System.Numerics.Tensors.IReadOnlyTensor, T>.Strides { get { throw null; } } - bool System.Numerics.Tensors.ITensor, T>.IsReadOnly { get { throw null; } } + bool System.Numerics.Tensors.ITensor.IsReadOnly { get { throw null; } } + object System.Numerics.Tensors.ITensor.this[params scoped System.ReadOnlySpan indexes] { get { throw null; } set { } } + object System.Numerics.Tensors.ITensor.this[params scoped System.ReadOnlySpan indexes] { get { throw null; } set { } } T System.Numerics.Tensors.ITensor, T>.this[params scoped System.ReadOnlySpan indexes] { get { throw null; } set { } } T System.Numerics.Tensors.ITensor, T>.this[params scoped System.ReadOnlySpan indexes] { get { throw null; } set { } } public System.Numerics.Tensors.ReadOnlyTensorSpan AsReadOnlyTensorSpan() { throw null; } @@ -849,12 +867,14 @@ internal Tensor() { } public System.Numerics.Tensors.TensorSpan AsTensorSpan(params scoped System.ReadOnlySpan start) { throw null; } public void Clear() { } public void CopyTo(scoped System.Numerics.Tensors.TensorSpan destination) { } + public void Fill(object value) { } public void Fill(T value) { } public void FlattenTo(scoped System.Span destination) { } public System.Collections.Generic.IEnumerator GetEnumerator() { throw null; } public override int GetHashCode() { throw null; } [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] public ref T GetPinnableReference() { throw null; } + public System.Buffers.MemoryHandle GetPinnedHandle() { throw null; } public static implicit operator System.Numerics.Tensors.ReadOnlyTensorSpan (System.Numerics.Tensors.Tensor value) { throw null; } public static implicit operator System.Numerics.Tensors.TensorSpan (System.Numerics.Tensors.Tensor value) { throw null; } public static implicit operator System.Numerics.Tensors.Tensor (T[] array) { throw null; } diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/IReadOnlyTensor.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/IReadOnlyTensor.cs index 9d42fa96d3f8af..ce52ab8c279b91 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/IReadOnlyTensor.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/IReadOnlyTensor.cs @@ -7,20 +7,13 @@ namespace System.Numerics.Tensors { + /// /// Represents a read-only tensor. /// - /// The type that implements this interface. - /// The element type. [Experimental(Experimentals.TensorTDiagId, UrlFormat = Experimentals.SharedUrlFormat)] - public interface IReadOnlyTensor : IEnumerable - where TSelf : IReadOnlyTensor + public interface IReadOnlyTensor { - /// - /// Gets an empty tensor. - /// - static abstract TSelf? Empty { get; } - /// /// Gets a value that indicates whether the collection is currently empty. /// @@ -41,17 +34,62 @@ public interface IReadOnlyTensor : IEnumerable /// int Rank { get; } + /// + /// Gets the length of each dimension in the tensor. + /// + [UnscopedRef] + ReadOnlySpan Lengths { get; } + + /// + /// Gets the stride of each dimension in the tensor. + /// + [UnscopedRef] + ReadOnlySpan Strides { get; } + + /// + /// Gets the value at the specified indexes. + /// + /// The indexes to be used. + object this[params scoped ReadOnlySpan indexes] { get; } + + /// + /// Gets the value at the specified indexes. + /// + /// The indexes to be used. + object this[params scoped ReadOnlySpan indexes] { get; } + + /// + /// Pins and gets a to the backing memory. + /// + /// + MemoryHandle GetPinnedHandle(); + } + + /// + /// Represents a read-only tensor. + /// + /// The type that implements this interface. + /// The element type. + [Experimental(Experimentals.TensorTDiagId, UrlFormat = Experimentals.SharedUrlFormat)] + public interface IReadOnlyTensor : IReadOnlyTensor, IEnumerable + where TSelf : IReadOnlyTensor + { + /// + /// Gets an empty tensor. + /// + static abstract TSelf? Empty { get; } + /// /// Gets the value at the specified indexes. /// /// The indexes to be used. - T this[params scoped ReadOnlySpan indexes] { get; } + new T this[params scoped ReadOnlySpan indexes] { get; } /// /// Gets the value at the specified indexes. /// /// The indexes to be used. - T this[params scoped ReadOnlySpan indexes] { get; } + new T this[params scoped ReadOnlySpan indexes] { get; } /// /// Gets the values at the specified ranges. @@ -98,18 +136,6 @@ public interface IReadOnlyTensor : IEnumerable /// The destination span where the data should be flattened to. void FlattenTo(scoped Span destination); - /// - /// Gets the length of each dimension in the tensor. - /// - [UnscopedRef] - ReadOnlySpan Lengths { get; } - - /// - /// Gets the stride of each dimension in the tensor. - /// - [UnscopedRef] - ReadOnlySpan Strides { get; } - /// /// Returns a reference to the 0th element of the tensor. If the tensor is empty, returns . /// diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/ITensor.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/ITensor.cs index eba3af96f69df1..4a1cbd2b1bf800 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/ITensor.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/ITensor.cs @@ -6,14 +6,47 @@ namespace System.Numerics.Tensors { + /// + /// Represents a tensor. + /// + [Experimental(Experimentals.TensorTDiagId, UrlFormat = Experimentals.SharedUrlFormat)] + public interface ITensor : IReadOnlyTensor + { + /// + /// Gets the value at the specified indexes. + /// + /// The indexes to be used. + new object this[params scoped ReadOnlySpan indexes] { get; set; } + + /// + /// Gets the value at the specified indexes. + /// + /// The indexes to be used. + new object this[params scoped ReadOnlySpan indexes] { get; set; } + + /// + /// Gets a value that indicates whether the collection is read-only. + /// + bool IsReadOnly { get; } + + /// + /// Clears the tensor. + /// + void Clear(); + + /// + /// Fills the contents of this tensor with the given value. + /// + void Fill(object value); + } + /// /// Represents a tensor. /// /// The type that implements this interface. /// The element type. [Experimental(Experimentals.TensorTDiagId, UrlFormat = Experimentals.SharedUrlFormat)] - public interface ITensor - : IReadOnlyTensor + public interface ITensor : ITensor, IReadOnlyTensor where TSelf : ITensor { // TODO: Determine if we can implement `IEqualityOperators`. @@ -67,11 +100,6 @@ public interface ITensor /// static abstract TSelf CreateUninitialized(scoped ReadOnlySpan lengths, scoped ReadOnlySpan strides, bool pinned = false); - /// - /// Gets a value that idicates whether the collection is read-only. - /// - bool IsReadOnly { get; } - /// /// Gets the value at the specified indexes. /// @@ -117,11 +145,6 @@ public interface ITensor /// The converted . TensorSpan AsTensorSpan(params scoped ReadOnlySpan range); - /// - /// Clears the tensor. - /// - void Clear(); - /// /// Fills the contents of this tensor with the given value. /// diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs index 29e3cd48f60fa2..3596a04cdfc44e 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Tensor.cs @@ -21,8 +21,7 @@ namespace System.Numerics.Tensors /// Represents a tensor. /// [Experimental(Experimentals.TensorTDiagId, UrlFormat = Experimentals.SharedUrlFormat)] - public sealed class Tensor - : ITensor, T> + public sealed class Tensor : ITensor, ITensor, T> { /// A byref or a native ptr. internal readonly T[] _values; @@ -176,7 +175,7 @@ static Tensor ITensor, T>.CreateUninitialized(scoped ReadOnlySpan. /// /// with the lengths of each dimension. - ReadOnlySpan IReadOnlyTensor, T>.Lengths => _lengths; + ReadOnlySpan IReadOnlyTensor.Lengths => _lengths; /// @@ -189,9 +188,16 @@ static Tensor ITensor, T>.CreateUninitialized(scoped ReadOnlySpan. /// /// with the strides of each dimension. - ReadOnlySpan IReadOnlyTensor, T>.Strides => _strides; + ReadOnlySpan IReadOnlyTensor.Strides => _strides; - bool ITensor, T>.IsReadOnly => false; + bool ITensor.IsReadOnly => false; + + object IReadOnlyTensor.this[params scoped ReadOnlySpan indexes] => this[indexes]!; + + object IReadOnlyTensor.this[params scoped ReadOnlySpan indexes] => this[indexes]!; + + object ITensor.this[params scoped ReadOnlySpan indexes] { get => this[indexes]!; set => this[indexes] = (T)value; } + object ITensor.this[params scoped ReadOnlySpan indexes] { get => this[indexes]!; set => this[indexes] = (T)value; } /// /// Returns a reference to specified element of the Tensor. @@ -524,6 +530,11 @@ public Tensor Slice(params ReadOnlySpan startIndex) [MethodImpl(MethodImplOptions.AggressiveInlining)] public void Fill(T value) => AsTensorSpan().Fill(value); + /// + /// Fills the contents of this span with the given value. + /// + public void Fill(object value) => Fill(value is T t ? t : throw new ArgumentException($"Cannot convert {value} to {typeof(T)}")); + /// /// Copies the contents of this tensor into destination span. If the source /// and destinations overlap, this method behaves as if the original values in @@ -671,5 +682,18 @@ public string ToString(params ReadOnlySpan maximumLengths) sb.AppendLine("}"); return sb.ToString(); } + + /// + /// Pins and gets a to the backing memory. + /// + /// A which has pinned the backing memory. + public MemoryHandle GetPinnedHandle() + { + GCHandle handle = GCHandle.Alloc(_values, GCHandleType.Pinned); + unsafe + { + return new MemoryHandle(Unsafe.AsPointer(ref GetPinnableReference()), handle); + } + } } } diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs index 607d09788db1ce..608d8a85506614 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; @@ -2644,5 +2645,73 @@ public void TensorFilteredUpdateTest() Tensor.FilteredUpdate(tensor1.AsTensorSpan(), filter, replace); Assert.Equal(new int[] { -1, -1, -1 }, tensor1.ToArray()); } + + [Fact] + public void TensorObjectFillTests() + { + ITensor tensor = (ITensor)new Tensor(new int[4], new nint[] { 2, 2 }); + tensor.Fill(5); + + Assert.Equal(5, tensor[0, 0]); + Assert.Equal(5, tensor[0, 1]); + Assert.Equal(5, tensor[1, 0]); + Assert.Equal(5, tensor[1, 1]); + + Assert.Throws(() => tensor.Fill("invalid")); + Assert.Throws(() => tensor.Fill(null)); + + tensor.Fill((object)5); + Assert.Equal(5, tensor[0, 0]); + Assert.Equal(5, tensor[0, 1]); + Assert.Equal(5, tensor[1, 0]); + Assert.Equal(5, tensor[1, 1]); + } + + [Fact] + public void TensorObjectIndexerTests() + { + ITensor tensor = new Tensor(new int[] { 1, 2, 3, 4 }, new nint[] { 2, 2 }); + + Assert.Equal(1, tensor[new nint[] { 0, 0 }]); + Assert.Equal(2, tensor[new nint[] { 0, 1 }]); + Assert.Equal(3, tensor[new nint[] { 1, 0 }]); + Assert.Equal(4, tensor[new nint[] { 1, 1 }]); + + tensor[new nint[] { 0, 0 }] = 10; + tensor[new nint[] { 0, 1 }] = 20; + tensor[new nint[] { 1, 0 }] = 30; + tensor[new nint[] { 1, 1 }] = 40; + + Assert.Equal(10, tensor[new nint[] { 0, 0 }]); + Assert.Equal(20, tensor[new nint[] { 0, 1 }]); + Assert.Equal(30, tensor[new nint[] { 1, 0 }]); + Assert.Equal(40, tensor[new nint[] { 1, 1 }]); + + Assert.Throws(() => tensor[new nint[] { 2, 0 }]); + Assert.Throws(() => tensor[new nint[] { 0, 2 }]); + Assert.Throws(() => tensor[new nint[] { -1, 0 }]); + Assert.Throws(() => tensor[new nint[] { -1, -1 }]); + + Assert.Throws(() => tensor[new nint[] { 2, 0 }] = 10); + Assert.Throws(() => tensor[new nint[] { 0, 2 }] = 20); + Assert.Throws(() => tensor[new nint[] { -1, 0 }] = 20); + Assert.Throws(() => tensor[new nint[] { -1, -1 }] = 20); + } + + [Fact] + public void TensorGetPinnedHandleTests() + { + Tensor tensor = new Tensor(new int[] { 1, 2, 3, 4 }, new nint[] { 2, 2 }); + + using MemoryHandle handle = tensor.GetPinnedHandle(); + unsafe + { + int* ptr = (int*)handle.Pointer; + Assert.Equal(1, ptr[0]); + Assert.Equal(2, ptr[1]); + Assert.Equal(3, ptr[2]); + Assert.Equal(4, ptr[3]); + } + } } }