Skip to content
Closed
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 272 additions & 16 deletions src/Microsoft.ML.Dnn/TensorTypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,305 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Globalization;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.Internal.Utilities;
using NumSharp.Backends;
using NumSharp.Backends.Unmanaged;
using NumSharp.Utilities;
using Tensorflow;

#if _REGEN_GLOBAL
%supported_numericals = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Double","Single"]
%supported_numericals_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","double","float"]
%supported_numericals_TF_DataType = ["TF_BOOL","TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_DOUBLE","TF_FLOAT"]
%supported_numericals_TF_DataType_full = ["TF_DataType.TF_BOOL","TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"]
#endif

namespace Microsoft.ML.Transforms
{
[BestFriend]
internal static class TensorTypeExtensions
{
/// <summary>
/// Writes a single value taken from <paramref name="tensor"/> into <paramref name="dst"/>.
/// </summary>
/// <typeparam name="T">Unmanaged destination type.</typeparam>
/// <param name="tensor">Tensor whose buffer (or string data) supplies the value.</param>
/// <param name="dst">Receives the value as <typeparamref name="T"/>.</param>
/// <exception cref="NotSupportedException">Thrown for tensor dtypes the body has no case for.</exception>
public static void ToScalar<T>(this Tensor tensor, ref T dst) where T : unmanaged
{

@codemzs codemzs Oct 3, 2019

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this handle string tensors? #Closed

@Nucs Nucs Oct 3, 2019

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was sure I had String-Tensor -> T conversion implemented; I have added it now. #Closed

// Strict variant: reject any tensor whose dtype is not the dtype mapped from T,
// then reinterpret the start of the buffer as a T.
// NOTE(review): this check and read overlap with the guarded fast path further down;
// presumably these are the removed lines of the diff hunk. Confirm which version is current.
if (typeof(T).as_dtype() != tensor.dtype)
throw new NotSupportedException();

unsafe
{
dst = *(T*)tensor.buffer;
}
// Fast path: dtype already matches T and is not a string, so the buffer can be read as T directly.
if (typeof(T).as_dtype() == tensor.dtype && tensor.dtype != TF_DataType.TF_STRING)
{
dst = *(T*) tensor.buffer;
return;
}

// Conversion path: read one element in the native type of the tensor and convert it to T.
//TODO When upgrading to the newest version of Tensorflow.NET, NumSharp will consequently upgrade too, just remove the second argument of the ChangeType calls.
switch (tensor.dtype)
{
#if _REGEN
%foreach supported_numericals_TF_DataType,supported_numericals,supported_numericals_lowercase%
case TF_DataType.#1:
dst = Converts.ChangeType<T>(*(#3*) tensor.buffer, NPTypeCode.#2);
return;
%
#else

// Expanded form of the template above: one case per numeric dtype.
case TF_DataType.TF_BOOL:
dst = Converts.ChangeType<T>(*(bool*) tensor.buffer, NPTypeCode.Boolean);
return;
case TF_DataType.TF_UINT8:
dst = Converts.ChangeType<T>(*(byte*) tensor.buffer, NPTypeCode.Byte);
return;
case TF_DataType.TF_INT16:
dst = Converts.ChangeType<T>(*(short*) tensor.buffer, NPTypeCode.Int16);
return;
case TF_DataType.TF_UINT16:
dst = Converts.ChangeType<T>(*(ushort*) tensor.buffer, NPTypeCode.UInt16);
return;
case TF_DataType.TF_INT32:
dst = Converts.ChangeType<T>(*(int*) tensor.buffer, NPTypeCode.Int32);
return;
case TF_DataType.TF_UINT32:
dst = Converts.ChangeType<T>(*(uint*) tensor.buffer, NPTypeCode.UInt32);
return;
case TF_DataType.TF_INT64:
dst = Converts.ChangeType<T>(*(long*) tensor.buffer, NPTypeCode.Int64);
return;
case TF_DataType.TF_UINT64:
dst = Converts.ChangeType<T>(*(ulong*) tensor.buffer, NPTypeCode.UInt64);
return;
case TF_DataType.TF_DOUBLE:
dst = Converts.ChangeType<T>(*(double*) tensor.buffer, NPTypeCode.Double);
return;
case TF_DataType.TF_FLOAT:
dst = Converts.ChangeType<T>(*(float*) tensor.buffer, NPTypeCode.Single);
return;
#endif
// String tensors: convert the first entry returned by StringData() to T.
case TF_DataType.TF_STRING:
dst = Converts.ChangeType<T>(tensor.StringData()[0], NPTypeCode.String);
return;
// Complex dtypes and anything not listed above are rejected.
case TF_DataType.TF_COMPLEX64:
case TF_DataType.TF_COMPLEX128:
default:
throw new NotSupportedException();
}
}
}

// Two renderings of the same CopyTo signature: the span parameter is named values in the first
// and destination in the second.
// NOTE(review): presumably the removed/added pair of this diff hunk; confirm only one survives.
public static void CopyTo<T>(this Tensor tensor, Span<T> values) where T: unmanaged
public static void CopyTo<T>(this Tensor tensor, Span<T> destination) where T : unmanaged

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public static void CopyTo<T>(this Tensor tensor, Span<T> destination) where T : unmanaged [](start = 8, length = 89)

There is a lot of unsafe code in here that does memory copy/manipulation; have you considered using Span<T>.CopyTo(Span<T>) and Span<T>.Slice(int, int) to achieve the same?

@Nucs Nucs Oct 3, 2019

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see the point; I perform all the necessary checks beforehand. The code is safe. Any use of Span will result in unnecessary overhead.
Plus, is there a way to perform a cast from one dtype to another using only Span?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@codemzs I can replace:
var sdst = (bool*) Unsafe.AsPointer(ref destination.GetPinnableReference());
With:
fixed (byte* sdst = MemoryMarshal.Cast<T, byte>(destination))

They perform similar logic, just with more overhead.

{
// Strict variant: only tensors whose dtype equals the dtype mapped from T are accepted.
// NOTE(review): this check overlaps with the converting code further down; presumably it is
// one of the removed lines of the diff hunk. Confirm which version is current.
if (typeof(T).as_dtype() != tensor.dtype)
throw new NotSupportedException();

unsafe
{
// NOTE(review): len is declared twice in this scope (here and four lines below), and this
// copy targets values while the later one targets destination; this looks like the old and
// new versions of the hunk shown together.
var len = checked((int)tensor.size);
var src = (T*)tensor.buffer;
var span = new Span<T>(src, len);
span.CopyTo(values);
// Element count of the tensor; the checked cast throws instead of truncating when size exceeds int range.
var len = checked((int) tensor.size);
//perform regular CopyTo using Span.CopyTo.
if (typeof(T).as_dtype() == tensor.dtype && tensor.dtype != TF_DataType.TF_STRING) //T can't be a string but tensor can.
{
// Same dtype: view the tensor buffer as a span of T and copy it in one call.
var src = (T*) tensor.buffer;
var srcSpan = new Span<T>(src, len);
srcSpan.CopyTo(destination);

return;
}

// The raw pointer loops below do no bounds checking, so the length is validated up front.
// NOTE(review): the exception message misspells Destination.
if (len > destination.Length)
throw new ArgumentException("Destinion was too short to perform CopyTo.");

//Perform cast to type <T>.
// destination stays pinned for the whole switch, so dst is a stable pointer to its first element.
fixed (T* dst_ = destination)
{
var dst = dst_;
switch (tensor.dtype)
{
#if _REGEN
%foreach supported_numericals_TF_DataType,supported_numericals,supported_numericals_lowercase%
case TF_DataType.#1:
{
var converter = Converts.FindConverter<#3, T>();
var src = (#3*) tensor.buffer;
for (var i = 0; i < len; i++)
*(dst + i) = converter(unchecked(*(src + i)));
return;
}
%
#else
// Expanded form of the template above. Each numeric case fetches a converter from the
// native element type to T once, then applies it element by element.
case TF_DataType.TF_BOOL:
{
var converter = Converts.FindConverter<bool, T>();
var src = (bool*) tensor.buffer;
for (var i = 0; i < len; i++)
*(dst + i) = converter(unchecked(*(src + i)));
return;
}
case TF_DataType.TF_UINT8:
{
var converter = Converts.FindConverter<byte, T>();
var src = (byte*) tensor.buffer;
for (var i = 0; i < len; i++)
*(dst + i) = converter(unchecked(*(src + i)));
return;
}
case TF_DataType.TF_INT16:
{
var converter = Converts.FindConverter<short, T>();
var src = (short*) tensor.buffer;
for (var i = 0; i < len; i++)
*(dst + i) = converter(unchecked(*(src + i)));
return;
}
case TF_DataType.TF_UINT16:
{
var converter = Converts.FindConverter<ushort, T>();
var src = (ushort*) tensor.buffer;
for (var i = 0; i < len; i++)
*(dst + i) = converter(unchecked(*(src + i)));
return;
}
case TF_DataType.TF_INT32:
{
var converter = Converts.FindConverter<int, T>();
var src = (int*) tensor.buffer;
for (var i = 0; i < len; i++)
*(dst + i) = converter(unchecked(*(src + i)));
return;
}
case TF_DataType.TF_UINT32:
{
var converter = Converts.FindConverter<uint, T>();
var src = (uint*) tensor.buffer;
for (var i = 0; i < len; i++)
*(dst + i) = converter(unchecked(*(src + i)));
return;
}
case TF_DataType.TF_INT64:
{
var converter = Converts.FindConverter<long, T>();
var src = (long*) tensor.buffer;
for (var i = 0; i < len; i++)
*(dst + i) = converter(unchecked(*(src + i)));
return;
}
case TF_DataType.TF_UINT64:
{
var converter = Converts.FindConverter<ulong, T>();
var src = (ulong*) tensor.buffer;
for (var i = 0; i < len; i++)
*(dst + i) = converter(unchecked(*(src + i)));
return;
}
case TF_DataType.TF_DOUBLE:
{
var converter = Converts.FindConverter<double, T>();
var src = (double*) tensor.buffer;
for (var i = 0; i < len; i++)
*(dst + i) = converter(unchecked(*(src + i)));
return;
}
case TF_DataType.TF_FLOAT:
{
var converter = Converts.FindConverter<float, T>();
var src = (float*) tensor.buffer;
for (var i = 0; i < len; i++)
*(dst + i) = converter(unchecked(*(src + i)));
return;
}
#endif
case TF_DataType.TF_STRING:
{
// String tensors: fetch the managed strings, then convert each one through IConvertible
// using the invariant culture so the result does not depend on the current culture.
var src = tensor.StringData();
var culture = CultureInfo.InvariantCulture;

// Dispatch on the dtype mapped from T to choose the matching IConvertible method.
switch (typeof(T).as_dtype())
{
#if _REGEN
%foreach supported_numericals_TF_DataType,supported_numericals,supported_numericals_lowercase%
case TF_DataType.#1: {
var sdst = (#3*)Unsafe.AsPointer(ref destination.GetPinnableReference());
for (var i = 0; i < len; i++)
*(sdst + i) = ((IConvertible)src[i]).To#2(culture);
return;
}
%
#else
// Expanded form of the template above. sdst is taken from the pinnable reference of
// destination, which the enclosing fixed statement already pins.
// NOTE(review): casting dst would presumably yield the same address; confirm before simplifying.
case TF_DataType.TF_BOOL: {
var sdst = (bool*)Unsafe.AsPointer(ref destination.GetPinnableReference());
for (var i = 0; i < len; i++)
*(sdst + i) = ((IConvertible)src[i]).ToBoolean(culture);
return;
}
case TF_DataType.TF_UINT8: {
var sdst = (byte*)Unsafe.AsPointer(ref destination.GetPinnableReference());
for (var i = 0; i < len; i++)
*(sdst + i) = ((IConvertible)src[i]).ToByte(culture);
return;
}
case TF_DataType.TF_INT16: {
var sdst = (short*)Unsafe.AsPointer(ref destination.GetPinnableReference());
for (var i = 0; i < len; i++)
*(sdst + i) = ((IConvertible)src[i]).ToInt16(culture);
return;
}
case TF_DataType.TF_UINT16: {
var sdst = (ushort*)Unsafe.AsPointer(ref destination.GetPinnableReference());
for (var i = 0; i < len; i++)
*(sdst + i) = ((IConvertible)src[i]).ToUInt16(culture);
return;
}
case TF_DataType.TF_INT32: {
var sdst = (int*)Unsafe.AsPointer(ref destination.GetPinnableReference());
for (var i = 0; i < len; i++)
*(sdst + i) = ((IConvertible)src[i]).ToInt32(culture);
return;
}
case TF_DataType.TF_UINT32: {
var sdst = (uint*)Unsafe.AsPointer(ref destination.GetPinnableReference());
for (var i = 0; i < len; i++)
*(sdst + i) = ((IConvertible)src[i]).ToUInt32(culture);
return;
}
case TF_DataType.TF_INT64: {
var sdst = (long*)Unsafe.AsPointer(ref destination.GetPinnableReference());
for (var i = 0; i < len; i++)
*(sdst + i) = ((IConvertible)src[i]).ToInt64(culture);
return;
}
case TF_DataType.TF_UINT64: {
var sdst = (ulong*)Unsafe.AsPointer(ref destination.GetPinnableReference());
for (var i = 0; i < len; i++)
*(sdst + i) = ((IConvertible)src[i]).ToUInt64(culture);
return;
}
case TF_DataType.TF_DOUBLE: {
var sdst = (double*)Unsafe.AsPointer(ref destination.GetPinnableReference());
for (var i = 0; i < len; i++)
*(sdst + i) = ((IConvertible)src[i]).ToDouble(culture);
return;
}
case TF_DataType.TF_FLOAT: {
var sdst = (float*)Unsafe.AsPointer(ref destination.GetPinnableReference());
for (var i = 0; i < len; i++)
*(sdst + i) = ((IConvertible)src[i]).ToSingle(culture);
return;
}
#endif
// T maps to a dtype with no string conversion case.
default:
throw new NotSupportedException();
}
}
// Complex dtypes and anything not listed above are rejected.
case TF_DataType.TF_COMPLEX64:
case TF_DataType.TF_COMPLEX128:
default:
throw new NotSupportedException();
}
}
}
}

/// <summary>
/// Copies the contents of <paramref name="tensor"/> into <paramref name="array"/> by wrapping the
/// array in a span and delegating to CopyTo.
/// </summary>
/// <typeparam name="T">Unmanaged element type of the destination array.</typeparam>
/// <param name="tensor">Source tensor.</param>
/// <param name="array">Destination array, passed by reference to Utils.EnsureSize before the copy.</param>
public static void ToArray<T>(this Tensor tensor, ref T[] array) where T : unmanaged
{
// Presumably EnsureSize resizes array so it can hold tensor.size elements; verify against Utils.
// NOTE(review): the two calls below differ only in the spacing after the (int) cast, so they
// look like the removed/added pair of this diff hunk. Confirm only one belongs in the file.
// Note that these casts are unchecked, unlike the checked cast CopyTo applies to tensor.size.
Utils.EnsureSize(ref array, (int)tensor.size, (int)tensor.size, false);
Utils.EnsureSize(ref array, (int) tensor.size, (int) tensor.size, false);
// The span covers the whole array, not only the first tensor.size elements.
var span = new Span<T>(array);

CopyTo(tensor, span);
}
}
}
}