-
Notifications
You must be signed in to change notification settings - Fork 1.9k
TensorTypeExtensions: Added conversion between Tensor to primitive C# types instead of throwing NotSupportedException #4290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
7934428
f067269
effdc8a
1631962
6b38752
14362ea
1e979f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,49 +3,305 @@ | |
| // See the LICENSE file in the project root for more information. | ||
|
|
||
| using System; | ||
| using System.Globalization; | ||
| using System.Runtime.CompilerServices; | ||
| using System.Runtime.InteropServices; | ||
| using System.Text; | ||
| using System.Threading.Tasks; | ||
| using Microsoft.ML.Internal.Utilities; | ||
| using NumSharp.Backends; | ||
| using NumSharp.Backends.Unmanaged; | ||
| using NumSharp.Utilities; | ||
| using Tensorflow; | ||
|
|
||
| #if _REGEN_GLOBAL | ||
| %supported_numericals = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Double","Single"] | ||
| %supported_numericals_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","double","float"] | ||
| %supported_numericals_TF_DataType = ["TF_BOOL","TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_DOUBLE","TF_FLOAT"] | ||
| %supported_numericals_TF_DataType_full = ["TF_DataType.TF_BOOL","TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | ||
| #endif | ||
|
|
||
| namespace Microsoft.ML.Transforms | ||
| { | ||
| [BestFriend] | ||
| internal static class TensorTypeExtensions | ||
| { | ||
| public static void ToScalar<T>(this Tensor tensor, ref T dst) where T : unmanaged | ||
| { | ||
| if (typeof(T).as_dtype() != tensor.dtype) | ||
| throw new NotSupportedException(); | ||
|
|
||
| unsafe | ||
| { | ||
| dst = *(T*)tensor.buffer; | ||
| } | ||
| if (typeof(T).as_dtype() == tensor.dtype && tensor.dtype != TF_DataType.TF_STRING) | ||
| { | ||
| dst = *(T*) tensor.buffer; | ||
| return; | ||
| } | ||
|
|
||
| //TODO When upgrading to the newest version of Tensorflow.NET, NumSharp will consequently upgrade too, just remove the second argument of the ChangeType calls. | ||
| switch (tensor.dtype) | ||
| { | ||
| #if _REGEN | ||
| %foreach supported_numericals_TF_DataType,supported_numericals,supported_numericals_lowercase% | ||
| case TF_DataType.#1: | ||
| dst = Converts.ChangeType<T>(*(#3*) tensor.buffer, NPTypeCode.#2); | ||
| return; | ||
| % | ||
| #else | ||
|
|
||
| case TF_DataType.TF_BOOL: | ||
| dst = Converts.ChangeType<T>(*(bool*) tensor.buffer, NPTypeCode.Boolean); | ||
| return; | ||
| case TF_DataType.TF_UINT8: | ||
| dst = Converts.ChangeType<T>(*(byte*) tensor.buffer, NPTypeCode.Byte); | ||
| return; | ||
| case TF_DataType.TF_INT16: | ||
| dst = Converts.ChangeType<T>(*(short*) tensor.buffer, NPTypeCode.Int16); | ||
| return; | ||
| case TF_DataType.TF_UINT16: | ||
| dst = Converts.ChangeType<T>(*(ushort*) tensor.buffer, NPTypeCode.UInt16); | ||
| return; | ||
| case TF_DataType.TF_INT32: | ||
| dst = Converts.ChangeType<T>(*(int*) tensor.buffer, NPTypeCode.Int32); | ||
| return; | ||
| case TF_DataType.TF_UINT32: | ||
| dst = Converts.ChangeType<T>(*(uint*) tensor.buffer, NPTypeCode.UInt32); | ||
| return; | ||
| case TF_DataType.TF_INT64: | ||
| dst = Converts.ChangeType<T>(*(long*) tensor.buffer, NPTypeCode.Int64); | ||
| return; | ||
| case TF_DataType.TF_UINT64: | ||
| dst = Converts.ChangeType<T>(*(ulong*) tensor.buffer, NPTypeCode.UInt64); | ||
| return; | ||
| case TF_DataType.TF_DOUBLE: | ||
| dst = Converts.ChangeType<T>(*(double*) tensor.buffer, NPTypeCode.Double); | ||
| return; | ||
| case TF_DataType.TF_FLOAT: | ||
| dst = Converts.ChangeType<T>(*(float*) tensor.buffer, NPTypeCode.Single); | ||
| return; | ||
| #endif | ||
| case TF_DataType.TF_STRING: | ||
| dst = Converts.ChangeType<T>(tensor.StringData()[0], NPTypeCode.String); | ||
| return; | ||
| case TF_DataType.TF_COMPLEX64: | ||
| case TF_DataType.TF_COMPLEX128: | ||
| default: | ||
| throw new NotSupportedException(); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| public static void CopyTo<T>(this Tensor tensor, Span<T> values) where T: unmanaged | ||
| public static void CopyTo<T>(this Tensor tensor, Span<T> destination) where T : unmanaged | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There is a lot of unsafe code in here that does memory copy/manipulation, have you considered using Span.CopyTo(Span) and Span.Slice(int,int) to achieve the same?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see a point, I perform all the checks necessary beforehand. The code is safe. Any use of Span will result in unnecessary overhead.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @codemzs I can replace: They do a similar logic just with more overhead. |
||
| { | ||
| if (typeof(T).as_dtype() != tensor.dtype) | ||
| throw new NotSupportedException(); | ||
|
|
||
| unsafe | ||
| { | ||
| var len = checked((int)tensor.size); | ||
| var src = (T*)tensor.buffer; | ||
| var span = new Span<T>(src, len); | ||
| span.CopyTo(values); | ||
| var len = checked((int) tensor.size); | ||
| //perform regular CopyTo using Span.CopyTo. | ||
| if (typeof(T).as_dtype() == tensor.dtype && tensor.dtype != TF_DataType.TF_STRING) //T can't be a string but tensor can. | ||
| { | ||
| var src = (T*) tensor.buffer; | ||
| var srcSpan = new Span<T>(src, len); | ||
| srcSpan.CopyTo(destination); | ||
|
|
||
| return; | ||
| } | ||
|
|
||
| if (len > destination.Length) | ||
| throw new ArgumentException("Destinion was too short to perform CopyTo."); | ||
|
|
||
| //Perform cast to type <T>. | ||
| fixed (T* dst_ = destination) | ||
| { | ||
| var dst = dst_; | ||
| switch (tensor.dtype) | ||
| { | ||
| #if _REGEN | ||
| %foreach supported_numericals_TF_DataType,supported_numericals,supported_numericals_lowercase% | ||
| case TF_DataType.#1: | ||
| { | ||
| var converter = Converts.FindConverter<#3, T>(); | ||
| var src = (#3*) tensor.buffer; | ||
| for (var i = 0; i < len; i++) | ||
| *(dst + i) = converter(unchecked(*(src + i))); | ||
| return; | ||
| } | ||
| % | ||
| #else | ||
| case TF_DataType.TF_BOOL: | ||
| { | ||
| var converter = Converts.FindConverter<bool, T>(); | ||
| var src = (bool*) tensor.buffer; | ||
| for (var i = 0; i < len; i++) | ||
| *(dst + i) = converter(unchecked(*(src + i))); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_UINT8: | ||
| { | ||
| var converter = Converts.FindConverter<byte, T>(); | ||
| var src = (byte*) tensor.buffer; | ||
| for (var i = 0; i < len; i++) | ||
| *(dst + i) = converter(unchecked(*(src + i))); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_INT16: | ||
| { | ||
| var converter = Converts.FindConverter<short, T>(); | ||
| var src = (short*) tensor.buffer; | ||
| for (var i = 0; i < len; i++) | ||
| *(dst + i) = converter(unchecked(*(src + i))); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_UINT16: | ||
| { | ||
| var converter = Converts.FindConverter<ushort, T>(); | ||
| var src = (ushort*) tensor.buffer; | ||
| for (var i = 0; i < len; i++) | ||
| *(dst + i) = converter(unchecked(*(src + i))); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_INT32: | ||
| { | ||
| var converter = Converts.FindConverter<int, T>(); | ||
| var src = (int*) tensor.buffer; | ||
| for (var i = 0; i < len; i++) | ||
| *(dst + i) = converter(unchecked(*(src + i))); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_UINT32: | ||
| { | ||
| var converter = Converts.FindConverter<uint, T>(); | ||
| var src = (uint*) tensor.buffer; | ||
| for (var i = 0; i < len; i++) | ||
| *(dst + i) = converter(unchecked(*(src + i))); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_INT64: | ||
| { | ||
| var converter = Converts.FindConverter<long, T>(); | ||
| var src = (long*) tensor.buffer; | ||
| for (var i = 0; i < len; i++) | ||
| *(dst + i) = converter(unchecked(*(src + i))); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_UINT64: | ||
| { | ||
| var converter = Converts.FindConverter<ulong, T>(); | ||
| var src = (ulong*) tensor.buffer; | ||
| for (var i = 0; i < len; i++) | ||
| *(dst + i) = converter(unchecked(*(src + i))); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_DOUBLE: | ||
| { | ||
| var converter = Converts.FindConverter<double, T>(); | ||
| var src = (double*) tensor.buffer; | ||
| for (var i = 0; i < len; i++) | ||
| *(dst + i) = converter(unchecked(*(src + i))); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_FLOAT: | ||
| { | ||
| var converter = Converts.FindConverter<float, T>(); | ||
| var src = (float*) tensor.buffer; | ||
| for (var i = 0; i < len; i++) | ||
| *(dst + i) = converter(unchecked(*(src + i))); | ||
| return; | ||
| } | ||
| #endif | ||
| case TF_DataType.TF_STRING: | ||
| { | ||
| var src = tensor.StringData(); | ||
| var culture = CultureInfo.InvariantCulture; | ||
|
|
||
| switch (typeof(T).as_dtype()) | ||
| { | ||
| #if _REGEN | ||
| %foreach supported_numericals_TF_DataType,supported_numericals,supported_numericals_lowercase% | ||
| case TF_DataType.#1: { | ||
| var sdst = (#3*)Unsafe.AsPointer(ref destination.GetPinnableReference()); | ||
| for (var i = 0; i < len; i++) | ||
| *(sdst + i) = ((IConvertible)src[i]).To#2(culture); | ||
| return; | ||
| } | ||
| % | ||
| #else | ||
| case TF_DataType.TF_BOOL: { | ||
| var sdst = (bool*)Unsafe.AsPointer(ref destination.GetPinnableReference()); | ||
| for (var i = 0; i < len; i++) | ||
| *(sdst + i) = ((IConvertible)src[i]).ToBoolean(culture); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_UINT8: { | ||
| var sdst = (byte*)Unsafe.AsPointer(ref destination.GetPinnableReference()); | ||
| for (var i = 0; i < len; i++) | ||
| *(sdst + i) = ((IConvertible)src[i]).ToByte(culture); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_INT16: { | ||
| var sdst = (short*)Unsafe.AsPointer(ref destination.GetPinnableReference()); | ||
| for (var i = 0; i < len; i++) | ||
| *(sdst + i) = ((IConvertible)src[i]).ToInt16(culture); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_UINT16: { | ||
| var sdst = (ushort*)Unsafe.AsPointer(ref destination.GetPinnableReference()); | ||
| for (var i = 0; i < len; i++) | ||
| *(sdst + i) = ((IConvertible)src[i]).ToUInt16(culture); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_INT32: { | ||
| var sdst = (int*)Unsafe.AsPointer(ref destination.GetPinnableReference()); | ||
| for (var i = 0; i < len; i++) | ||
| *(sdst + i) = ((IConvertible)src[i]).ToInt32(culture); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_UINT32: { | ||
| var sdst = (uint*)Unsafe.AsPointer(ref destination.GetPinnableReference()); | ||
| for (var i = 0; i < len; i++) | ||
| *(sdst + i) = ((IConvertible)src[i]).ToUInt32(culture); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_INT64: { | ||
| var sdst = (long*)Unsafe.AsPointer(ref destination.GetPinnableReference()); | ||
| for (var i = 0; i < len; i++) | ||
| *(sdst + i) = ((IConvertible)src[i]).ToInt64(culture); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_UINT64: { | ||
| var sdst = (ulong*)Unsafe.AsPointer(ref destination.GetPinnableReference()); | ||
| for (var i = 0; i < len; i++) | ||
| *(sdst + i) = ((IConvertible)src[i]).ToUInt64(culture); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_DOUBLE: { | ||
| var sdst = (double*)Unsafe.AsPointer(ref destination.GetPinnableReference()); | ||
| for (var i = 0; i < len; i++) | ||
| *(sdst + i) = ((IConvertible)src[i]).ToDouble(culture); | ||
| return; | ||
| } | ||
| case TF_DataType.TF_FLOAT: { | ||
| var sdst = (float*)Unsafe.AsPointer(ref destination.GetPinnableReference()); | ||
| for (var i = 0; i < len; i++) | ||
| *(sdst + i) = ((IConvertible)src[i]).ToSingle(culture); | ||
| return; | ||
| } | ||
| #endif | ||
| default: | ||
| throw new NotSupportedException(); | ||
| } | ||
| } | ||
| case TF_DataType.TF_COMPLEX64: | ||
| case TF_DataType.TF_COMPLEX128: | ||
| default: | ||
| throw new NotSupportedException(); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| public static void ToArray<T>(this Tensor tensor, ref T[] array) where T : unmanaged | ||
| { | ||
| Utils.EnsureSize(ref array, (int)tensor.size, (int)tensor.size, false); | ||
| Utils.EnsureSize(ref array, (int) tensor.size, (int) tensor.size, false); | ||
| var span = new Span<T>(array); | ||
|
|
||
| CopyTo(tensor, span); | ||
| } | ||
| } | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this handle string tensors? #Closed
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was sure I had
String-Tensor -> Tconversion implemented, Added it. #Closed