Skip to content

Commit

Permalink
matrix classification via tflite model (#656)
Browse files Browse the repository at this point in the history
* cc: matrix_frame as input to graph

- "matrix_frame" as name in order to avoid confusion with matrix.cc
- matrix_frame is a 2D input data modality that gets converted to Eigen::MatrixXf internally
- suited for non-image input to tflite models

* cc: float_vector_frame as input to graph

- "float_vector_frame" as name in order to avoid confusion with FloatArrayPacket
- float_vector_frame is a 1D output data modality that accepts std::vector<float> outputs from a mediapipe / tflite graph
- suited for tflite classification results packaged as vector of floats

* MatrixFramePacket - c# helper functions

- send MatrixData to C++ as byte array

* FloatVectorFramePacket - c# helper functions

- return std::vector<float> from C++ to Unity as List<float>

* Unity: Matrix Classification - Example scene

* Matrix Classification.cs

- driver code for the newly added MatrixFramePacket and FloatVectorFramePacket
- feeds an example matrix of size [ 2 x 3 ] into a mediapipe graph
- the graph runs a simple tflite model (adds +1 to every input)
- then the graph returns the result back to Unity as List<float>

- only tested on Unity-Editor-Mode on Windows 10 Pro

* refactor: rename FloatVectorFrame -> FloatVector

- rename variables
- rename cs files
- rename cc files

* refactor: rename MatrixFrame -> Matrix

- rename variables
- rename cs files
- rename cc files matrix_frame -> matrix_data
-> avoid matrix.cc as it is already used as a name in mediapipe

* move MatrixClassification example scene to Tutorials

- does not represent an official solution
- not sure where else to place this
- MatrixClassification.cs is an important example for showcasing the usage of a tflite model with a matrix data input

* GetArrayPtr() - change access to private

* MatrixPacket: accept MatrixData as input

- before it was byte[]

* add license

* move native functions to Packet_Unsafe

- delete FloatVector_Unsafe.cs

* float_vector.cc -> faster vector allocation

* float_vector.cc remove unused function - delete(...)

* float_vector.h - remove unused headers

* refactor: float_vector.cc

TODO:
- implement GetFloatVector with vector size as argument

* removed unused headers

* refactor: float_vector.h

* refactor: apply autoformatter on cc files

- using format file ".clang-format" in project root

* refactor: mp__MakeMatrixFramePacket_At__PA_i_Rt -> mp__MakeMatrixPacket_At__PKc_i_Rt

* FloatVectorPacketTest added

- build similar to FloatArrayPacketTest
- not yet tested

* fix: float_vector.cc

* fix: MatrixPacket.cs

* fix: Test: FloatVectorPacketTest - Consume_ShouldThrowNotSupportedException

* MatrixPacketTest - add

- all tests involving packet.Get() do not work
- function is not yet implemented

* fix: Make MatrixClassification.cs run on Android

- adding StreamingAssets to ResourceManager

[skip actions]

* Update mediapipe_api/framework/formats/matrix_data.h

Co-authored-by: Junrou Nishida <[email protected]>

* Apply suggestions from code review

Co-authored-by: Junrou Nishida <[email protected]>

* float_vector - return vector size (+2 squashed commit)

Squashed commit:

[e409b05] refactor: vector_float.cc

- naming aligns with files like packet.cc

[bad3cd6] float_vector - return vector size

* fix: matrix_data.cc - wrong func name (+3 squashed commit)

Squashed commit:

[9245a37] fix: Revert "Apply suggestions from code review"

- the below mentioned commit is not working
- return value of inline function is invalid
-> probably due to inline function

This reverts commit c374d61.

[f597e83] fix: remove duplicate cpp func

[6def10c] fix: semicolon omitted

* FloatVectorPacket - replace list by array

fix: FloatVectorPacket

[skip actions] (+1 squashed commits)

Squashed commits:

[69302b1] FloatVectorPacket - replace list by array

- list is slow

* Add license headers

[skip actions]

* Remove Tutorial Scene: MatrixClassification

as per request:
- deleted demo / tutorial scene that showcasts a simple tflite graph
[skip actions]

* fix: MatrixPacket tests

- new GetMatrix function

Caveat:
- MatrixPacket: Consume throws NotSupportedException()
-> not sure if this is a useful test, but such tests exists in similar classes as well

(cherry picked from commit 707af5f454e87312a86b60deaebec18463e47ded)

Co-authored-by: Junrou Nishida <[email protected]>
  • Loading branch information
mgarbade and homuler committed Oct 15, 2022
1 parent e0a68a8 commit 1b1c135
Show file tree
Hide file tree
Showing 18 changed files with 682 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public override float[] Get()
return result;
}

public IntPtr GetArrayPtr()
private IntPtr GetArrayPtr()
{
UnsafeNativeMethods.mp_Packet__GetFloatArray(mpPtr, out var value).Assert();
GC.KeepAlive(this);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (c) 2021 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using System;
using System.Collections.Generic;
using System.Linq;

namespace Mediapipe
{
public class FloatVectorPacket : Packet<float[]>
{
/// <summary>
/// Creates an empty <see cref="FloatVectorPacket" /> instance.
/// </summary>
///

private int _vectorLength = -1;


public FloatVectorPacket() : base(true) { }

[UnityEngine.Scripting.Preserve]
public FloatVectorPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

public FloatVectorPacket(float[] value) : base()
{
UnsafeNativeMethods.mp__MakeFloatVectorPacket__PA_i(value, value.Length, out var ptr).Assert();
this.ptr = ptr;
_vectorLength = value.Length;
}

public FloatVectorPacket(float[] value, Timestamp timestamp) : base()
{
UnsafeNativeMethods.mp__MakeFloatVectorPacket_At__PA_i_Rt(value, value.Length, timestamp.mpPtr, out var ptr).Assert();
GC.KeepAlive(timestamp);
this.ptr = ptr;
}

public FloatVectorPacket At(Timestamp timestamp)
{
var packet = At<FloatVectorPacket>(timestamp);
packet._vectorLength = _vectorLength;
return packet;
}

public override float[] Get()
{
UnsafeNativeMethods.mp_Packet__GetFloatVector(mpPtr, out var floatFrameVector, out var size).Assert();
GC.KeepAlive(this);
if (size < 0)
{
throw new InvalidOperationException("The array's length is unknown, set Length first");
}

var result = new float[size];

unsafe
{
var src = (float*)floatFrameVector;

for (var i = 0; i < result.Length; i++)
{
result[i] = *src++;
}
}

return result;
}

public override StatusOr<float[]> Consume()
{
throw new NotSupportedException();
}

public override Status ValidateAsType()
{
UnsafeNativeMethods.mp_Packet__ValidateAsFloatVector(mpPtr, out var statusPtr).Assert();

GC.KeepAlive(this);
return new Status(statusPtr);
}
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright (c) 2021 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using Google.Protobuf;
using System;

namespace Mediapipe
{
public class MatrixPacket : Packet<MatrixData>
{
private int _length = -1;

public int length
{
get => _length;
set
{
if (_length >= 0)
{
throw new InvalidOperationException("Length is already set and cannot be changed");
}

_length = value;
}
}

/// <summary>
/// Creates an empty <see cref="MatrixPacket
/// " /> instance.
/// </summary>
public MatrixPacket() : base(true) { }

[UnityEngine.Scripting.Preserve]
public MatrixPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { }

public MatrixPacket(MatrixData matrixData) : base()
{
var value = matrixData.ToByteArray();
UnsafeNativeMethods.mp__MakeMatrixPacket__PKc_i(value, value.Length, out var ptr).Assert();
this.ptr = ptr;
length = value.Length;
}

public MatrixPacket(MatrixData matrixData, Timestamp timestamp) : base()
{
var value = matrixData.ToByteArray();
UnsafeNativeMethods.mp__MakeMatrixPacket_At__PKc_i_Rt(value, value.Length, timestamp.mpPtr, out var ptr).Assert();
GC.KeepAlive(timestamp);
this.ptr = ptr;
length = value.Length;
}

public MatrixPacket At(Timestamp timestamp)
{
var packet = At<MatrixPacket>(timestamp);
packet.length = length;
return packet;
}

public override MatrixData Get()
{
UnsafeNativeMethods.mp_Packet__GetMatrix(mpPtr, out var serializedMatrixData).Assert();
GC.KeepAlive(this);

var matrixData = serializedMatrixData.Deserialize(MatrixData.Parser);
serializedMatrixData.Dispose();

return matrixData;
}

public override StatusOr<MatrixData> Consume()
{
throw new NotSupportedException();
}

public override Status ValidateAsType()
{
UnsafeNativeMethods.mp_Packet__ValidateAsMatrix(mpPtr, out var statusPtr).Assert();

GC.KeepAlive(this);
return new Status(statusPtr);
}

}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2021 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using System;
using System.Runtime.InteropServices;

namespace Mediapipe
{
internal static partial class UnsafeNativeMethods
{
#region Packet
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeMatrixPacket__PKc_i(byte[] serializedMatrixData, int size, out IntPtr packet_out);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeMatrixPacket_At__PKc_i_Rt(byte[] serializedMatrixData, int size, IntPtr timestamp, out IntPtr packet_out);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__ValidateAsMatrix(IntPtr packet, out IntPtr status);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetMatrix(IntPtr packet, out SerializedProto serializedProto);

#endregion
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,20 @@ internal static partial class UnsafeNativeMethods
public static extern MpReturnCode mp_Packet__ValidateAsFloat(IntPtr packet, out IntPtr status);
#endregion

#region FloatVector
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeFloatVectorPacket__PA_i(float[] value, int size, out IntPtr packet);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeFloatVectorPacket_At__PA_i_Rt(float[] value, int size, IntPtr timestamp, out IntPtr packet);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__GetFloatVector(IntPtr packet, out IntPtr value, out int size);

[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp_Packet__ValidateAsFloatVector(IntPtr packet, out IntPtr status);
#endregion

#region Int
[DllImport(MediaPipeLibrary, ExactSpelling = true)]
public static extern MpReturnCode mp__MakeIntPacket__i(int value, out IntPtr packet);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright (c) 2021 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using System.Collections.Generic;
using NUnit.Framework;
using System;

namespace Mediapipe.Tests
{
public class FloatVectorPacketTest
{
#region Constructor
// [Test, SignalAbort]
// public void Ctor_ShouldInstantiatePacket_When_CalledWithNoArguments()
// {
// using (var packet = new FloatPacket())
// {
//#pragma warning disable IDE0058
// Assert.AreEqual(Status.StatusCode.Internal, packet.ValidateAsType().Code());
// Assert.Throws<MediaPipeException>(() => { packet.Get(); });
// Assert.AreEqual(Timestamp.Unset(), packet.Timestamp());
//#pragma warning restore IDE0058
// }
// }

[Test]
public void Ctor_ShouldInstantiatePacket_When_CalledWithValue()
{
var floatVector = new float[6] { 10, 11, 12, 13, 14, 15 };
using (var packet = new FloatVectorPacket(floatVector))
{
Assert.True(packet.ValidateAsType().Ok());
Assert.AreEqual(floatVector, packet.Get());
Assert.AreEqual(Timestamp.Unset(), packet.Timestamp());
}
}

//[Test]
//public void Ctor_ShouldInstantiatePacket_When_CalledWithValueAndTimestamp()
//{
// using (var timestamp = new Timestamp(1))
// {
// var floatArray = new float[6] { 10, 11, 12, 13, 14, 15 };
// using (var packet = new FloatPacket(floatArray, timestamp))
// {
// Assert.True(packet.ValidateAsType().Ok());
// Assert.AreEqual(0.01f, packet.Get());
// Assert.AreEqual(timestamp, packet.Timestamp());
// }
// }
//}
#endregion

#region #isDisposed
[Test]
public void IsDisposed_ShouldReturnFalse_When_NotDisposedYet()
{
using (var packet = new FloatVectorPacket())
{
Assert.False(packet.isDisposed);
}
}

[Test]
public void IsDisposed_ShouldReturnTrue_When_AlreadyDisposed()
{
var packet = new FloatVectorPacket();
packet.Dispose();

Assert.True(packet.isDisposed);
}
#endregion

#region #At
[Test]
public void At_ShouldReturnNewPacketWithTimestamp()
{
using (var timestamp = new Timestamp(1))
{
var floatVector = new float[6] { 10, 11, 12, 13, 14, 15 };
var packet = new FloatVectorPacket(floatVector).At(timestamp);
Assert.AreEqual(floatVector, packet.Get());
Assert.AreEqual(timestamp, packet.Timestamp());

using (var newTimestamp = new Timestamp(2))
{
var newPacket = packet.At(newTimestamp);
Assert.AreEqual(floatVector, newPacket.Get());
Assert.AreEqual(newTimestamp, newPacket.Timestamp());
}

Assert.AreEqual(timestamp, packet.Timestamp());
}
}
#endregion

#region #Consume
[Test]
public void Consume_ShouldThrowNotSupportedException()
{
using (var packet = new FloatVectorPacket())
{
#pragma warning disable IDE0058
Assert.Throws<NotSupportedException>(() => { packet.Consume(); });
#pragma warning restore IDE0058
}
}
#endregion

// #region #DebugTypeName
// [Test]
// public void DebugTypeName_ShouldReturnFloat_When_ValueIsSet()
// {
// using (var packet = new FloatPacket(0.01f))
// {
// Assert.AreEqual("float", packet.DebugTypeName());
// }
// }
// #endregion
}
}
Loading

0 comments on commit 1b1c135

Please sign in to comment.