-
-
Notifications
You must be signed in to change notification settings - Fork 465
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
matrix classification via tflite model (#656)
* cc: matrix_frame as input to graph - "matrix_frame" as name in order to avoid confusion with matrix.cc - matrix_frame is a 2D input data modality that gets converted to Eigen::MatrixXf internally - suited for non-image input to tflite models * cc: float_vector_frame as input to graph - "float_vector_frame" as name in order to avoid confusion with FloatArrayPacket - float_vector_frame is a 1D output data modality that accepts std::vector<float> outputs from a mediapipe / tflite graph - suited for tflite classification results packaged as vector of floats * MatrixFramePacket - c# helper functions - send MatrixData to C++ as byte array * FloatVectorFramePacket - c# helper functions - return std::vector<float> from C++ to Unity as List<float> * Unity: Matrix Classification - Example scene * Matrix Classification.cs - driver code for the newly added MatrixFramePacket and FloatVectorFramePacket - feeds an example matrix of size [ 2 x 3 ] into a mediapipe graph - the graph runs a simple tflite model (adds +1 to every input) - then the graph returns the result back to Unity as List<float> - only tested on Unity-Editor-Mode on Windows 10 Pro * refactor: rename FloatVectorFrame -> FloatVector - rename variables - rename cs files - rename cc files * refactor: rename MatrixFrame -> Matrix - rename variables - rename cs files - rename cc files matrix_frame -> matrix_data -> avoid matrix.cc as it is already used as a name in mediapipe * move MatrixClassification example scene to Tutorials - does not represent an official solution - not sure where else to place this - MatrixClassification.cs is an important example for showcasing the usage of a tflite model with a matrix data input * GetArrayPtr() - change access to private * MatrixPacket: accept MatrixData as input - before it was byte[] * add license * move native functions to Packet_Unsafe - delete FloatVector_Unsafe.cs * float_vector.cc -> faster vector allocation * float_vector.cc remove unused function - delete(...) * float_vector.h - remove unused headers * refactor: float_vector.cc TODO: - implement GetFloatVector with vector size as argument * removed unused headers * refactor: float_vector.h * refactor: apply autoformatter on cc files - using format file ".clang-format" in project root * refactor: mp__MakeMatrixFramePacket_At__PA_i_Rt -> mp__MakeMatrixPacket_At__PKc_i_Rt * FloatVectorPacketTest added - build similar to FloatArrayPacketTest - not yet tested * fix: float_vector.cc * fix: MatrixPacket.cs * fix: Test: FloatVectorPacketTest - Consume_ShouldThrowNotSupportedException * MatrixPacketTest - add - all tests involving packet.Get() do not work - function is not yet implemented * fix: Make MatrixClassification.cs run on Android - adding StreamingAssets to ResourceManager [skip actions] * Update mediapipe_api/framework/formats/matrix_data.h Co-authored-by: Junrou Nishida <[email protected]> * Apply suggestions from code review Co-authored-by: Junrou Nishida <[email protected]> * float_vector - return vector size (+2 squashed commit) Squashed commit: [e409b05] refactor: vector_float.cc - naming aligns with files like packet.cc [bad3cd6] float_vector - return vector size * fix: matrix_data.cc - wrong func name (+3 squashed commit) Squashed commit: [9245a37] fix: Revert "Apply suggestions from code review" - the below mentioned commit is not working - return value of inline function is invalid -> probably due to inline function This reverts commit c374d61. [f597e83] fix: remove duplicate cpp func [6def10c] fix: semicolon omitted * FloatVectorPacket - replace list by array fix: FloatVectorPacket [skip actions] (+1 squashed commits) Squashed commits: [69302b1] FloatVectorPacket - replace list by array - list is slow * Add license headers [skip actions] * Remove Tutorial Scene: MatrixClassification as per request: - deleted demo / tutorial scene that showcasts a simple tflite graph [skip actions] * fix: MatrixPacket tests - new GetMatrix function Caveat: - MatrixPacket: Consume throws NotSupportedException() -> not sure if this is a useful test, but such tests exists in similar classes as well (cherry picked from commit 707af5f454e87312a86b60deaebec18463e47ded) Co-authored-by: Junrou Nishida <[email protected]>
- Loading branch information
Showing
18 changed files
with
682 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
86 changes: 86 additions & 0 deletions
86
Packages/com.github.homuler.mediapipe/Runtime/Scripts/Framework/Packet/FloatVectorPacket.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
// Copyright (c) 2021 homuler | ||
// | ||
// Use of this source code is governed by an MIT-style | ||
// license that can be found in the LICENSE file or at | ||
// https://opensource.org/licenses/MIT. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
|
||
namespace Mediapipe | ||
{ | ||
public class FloatVectorPacket : Packet<float[]> | ||
{ | ||
/// <summary> | ||
/// Creates an empty <see cref="FloatVectorPacket" /> instance. | ||
/// </summary> | ||
/// | ||
|
||
private int _vectorLength = -1; | ||
|
||
|
||
public FloatVectorPacket() : base(true) { } | ||
|
||
[UnityEngine.Scripting.Preserve] | ||
public FloatVectorPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { } | ||
|
||
public FloatVectorPacket(float[] value) : base() | ||
{ | ||
UnsafeNativeMethods.mp__MakeFloatVectorPacket__PA_i(value, value.Length, out var ptr).Assert(); | ||
this.ptr = ptr; | ||
_vectorLength = value.Length; | ||
} | ||
|
||
public FloatVectorPacket(float[] value, Timestamp timestamp) : base() | ||
{ | ||
UnsafeNativeMethods.mp__MakeFloatVectorPacket_At__PA_i_Rt(value, value.Length, timestamp.mpPtr, out var ptr).Assert(); | ||
GC.KeepAlive(timestamp); | ||
this.ptr = ptr; | ||
} | ||
|
||
public FloatVectorPacket At(Timestamp timestamp) | ||
{ | ||
var packet = At<FloatVectorPacket>(timestamp); | ||
packet._vectorLength = _vectorLength; | ||
return packet; | ||
} | ||
|
||
public override float[] Get() | ||
{ | ||
UnsafeNativeMethods.mp_Packet__GetFloatVector(mpPtr, out var floatFrameVector, out var size).Assert(); | ||
GC.KeepAlive(this); | ||
if (size < 0) | ||
{ | ||
throw new InvalidOperationException("The array's length is unknown, set Length first"); | ||
} | ||
|
||
var result = new float[size]; | ||
|
||
unsafe | ||
{ | ||
var src = (float*)floatFrameVector; | ||
|
||
for (var i = 0; i < result.Length; i++) | ||
{ | ||
result[i] = *src++; | ||
} | ||
} | ||
|
||
return result; | ||
} | ||
|
||
public override StatusOr<float[]> Consume() | ||
{ | ||
throw new NotSupportedException(); | ||
} | ||
|
||
public override Status ValidateAsType() | ||
{ | ||
UnsafeNativeMethods.mp_Packet__ValidateAsFloatVector(mpPtr, out var statusPtr).Assert(); | ||
|
||
GC.KeepAlive(this); | ||
return new Status(statusPtr); | ||
} | ||
} | ||
} |
11 changes: 11 additions & 0 deletions
11
...s/com.github.homuler.mediapipe/Runtime/Scripts/Framework/Packet/FloatVectorPacket.cs.meta
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
88 changes: 88 additions & 0 deletions
88
Packages/com.github.homuler.mediapipe/Runtime/Scripts/Framework/Packet/MatrixPacket.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
// Copyright (c) 2021 homuler | ||
// | ||
// Use of this source code is governed by an MIT-style | ||
// license that can be found in the LICENSE file or at | ||
// https://opensource.org/licenses/MIT. | ||
|
||
using Google.Protobuf; | ||
using System; | ||
|
||
namespace Mediapipe | ||
{ | ||
public class MatrixPacket : Packet<MatrixData> | ||
{ | ||
private int _length = -1; | ||
|
||
public int length | ||
{ | ||
get => _length; | ||
set | ||
{ | ||
if (_length >= 0) | ||
{ | ||
throw new InvalidOperationException("Length is already set and cannot be changed"); | ||
} | ||
|
||
_length = value; | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Creates an empty <see cref="MatrixPacket | ||
/// " /> instance. | ||
/// </summary> | ||
public MatrixPacket() : base(true) { } | ||
|
||
[UnityEngine.Scripting.Preserve] | ||
public MatrixPacket(IntPtr ptr, bool isOwner = true) : base(ptr, isOwner) { } | ||
|
||
public MatrixPacket(MatrixData matrixData) : base() | ||
{ | ||
var value = matrixData.ToByteArray(); | ||
UnsafeNativeMethods.mp__MakeMatrixPacket__PKc_i(value, value.Length, out var ptr).Assert(); | ||
this.ptr = ptr; | ||
length = value.Length; | ||
} | ||
|
||
public MatrixPacket(MatrixData matrixData, Timestamp timestamp) : base() | ||
{ | ||
var value = matrixData.ToByteArray(); | ||
UnsafeNativeMethods.mp__MakeMatrixPacket_At__PKc_i_Rt(value, value.Length, timestamp.mpPtr, out var ptr).Assert(); | ||
GC.KeepAlive(timestamp); | ||
this.ptr = ptr; | ||
length = value.Length; | ||
} | ||
|
||
public MatrixPacket At(Timestamp timestamp) | ||
{ | ||
var packet = At<MatrixPacket>(timestamp); | ||
packet.length = length; | ||
return packet; | ||
} | ||
|
||
public override MatrixData Get() | ||
{ | ||
UnsafeNativeMethods.mp_Packet__GetMatrix(mpPtr, out var serializedMatrixData).Assert(); | ||
GC.KeepAlive(this); | ||
|
||
var matrixData = serializedMatrixData.Deserialize(MatrixData.Parser); | ||
serializedMatrixData.Dispose(); | ||
|
||
return matrixData; | ||
} | ||
|
||
public override StatusOr<MatrixData> Consume() | ||
{ | ||
throw new NotSupportedException(); | ||
} | ||
|
||
public override Status ValidateAsType() | ||
{ | ||
UnsafeNativeMethods.mp_Packet__ValidateAsMatrix(mpPtr, out var statusPtr).Assert(); | ||
|
||
GC.KeepAlive(this); | ||
return new Status(statusPtr); | ||
} | ||
|
||
} | ||
} |
11 changes: 11 additions & 0 deletions
11
Packages/com.github.homuler.mediapipe/Runtime/Scripts/Framework/Packet/MatrixPacket.cs.meta
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
29 changes: 29 additions & 0 deletions
29
...homuler.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Framework/Format/Matrix_Unsafe.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// Copyright (c) 2021 homuler | ||
// | ||
// Use of this source code is governed by an MIT-style | ||
// license that can be found in the LICENSE file or at | ||
// https://opensource.org/licenses/MIT. | ||
|
||
using System; | ||
using System.Runtime.InteropServices; | ||
|
||
namespace Mediapipe | ||
{ | ||
internal static partial class UnsafeNativeMethods | ||
{ | ||
#region Packet | ||
[DllImport(MediaPipeLibrary, ExactSpelling = true)] | ||
public static extern MpReturnCode mp__MakeMatrixPacket__PKc_i(byte[] serializedMatrixData, int size, out IntPtr packet_out); | ||
|
||
[DllImport(MediaPipeLibrary, ExactSpelling = true)] | ||
public static extern MpReturnCode mp__MakeMatrixPacket_At__PKc_i_Rt(byte[] serializedMatrixData, int size, IntPtr timestamp, out IntPtr packet_out); | ||
|
||
[DllImport(MediaPipeLibrary, ExactSpelling = true)] | ||
public static extern MpReturnCode mp_Packet__ValidateAsMatrix(IntPtr packet, out IntPtr status); | ||
|
||
[DllImport(MediaPipeLibrary, ExactSpelling = true)] | ||
public static extern MpReturnCode mp_Packet__GetMatrix(IntPtr packet, out SerializedProto serializedProto); | ||
|
||
#endregion | ||
} | ||
} |
11 changes: 11 additions & 0 deletions
11
...er.mediapipe/Runtime/Scripts/PInvoke/NativeMethods/Framework/Format/Matrix_Unsafe.cs.meta
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
124 changes: 124 additions & 0 deletions
124
...ges/com.github.homuler.mediapipe/Tests/EditMode/Framework/Packet/FloatVectorPacketTest.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
// Copyright (c) 2021 homuler | ||
// | ||
// Use of this source code is governed by an MIT-style | ||
// license that can be found in the LICENSE file or at | ||
// https://opensource.org/licenses/MIT. | ||
|
||
using System.Collections.Generic; | ||
using NUnit.Framework; | ||
using System; | ||
|
||
namespace Mediapipe.Tests | ||
{ | ||
public class FloatVectorPacketTest | ||
{ | ||
#region Constructor | ||
// [Test, SignalAbort] | ||
// public void Ctor_ShouldInstantiatePacket_When_CalledWithNoArguments() | ||
// { | ||
// using (var packet = new FloatPacket()) | ||
// { | ||
//#pragma warning disable IDE0058 | ||
// Assert.AreEqual(Status.StatusCode.Internal, packet.ValidateAsType().Code()); | ||
// Assert.Throws<MediaPipeException>(() => { packet.Get(); }); | ||
// Assert.AreEqual(Timestamp.Unset(), packet.Timestamp()); | ||
//#pragma warning restore IDE0058 | ||
// } | ||
// } | ||
|
||
[Test] | ||
public void Ctor_ShouldInstantiatePacket_When_CalledWithValue() | ||
{ | ||
var floatVector = new float[6] { 10, 11, 12, 13, 14, 15 }; | ||
using (var packet = new FloatVectorPacket(floatVector)) | ||
{ | ||
Assert.True(packet.ValidateAsType().Ok()); | ||
Assert.AreEqual(floatVector, packet.Get()); | ||
Assert.AreEqual(Timestamp.Unset(), packet.Timestamp()); | ||
} | ||
} | ||
|
||
//[Test] | ||
//public void Ctor_ShouldInstantiatePacket_When_CalledWithValueAndTimestamp() | ||
//{ | ||
// using (var timestamp = new Timestamp(1)) | ||
// { | ||
// var floatArray = new float[6] { 10, 11, 12, 13, 14, 15 }; | ||
// using (var packet = new FloatPacket(floatArray, timestamp)) | ||
// { | ||
// Assert.True(packet.ValidateAsType().Ok()); | ||
// Assert.AreEqual(0.01f, packet.Get()); | ||
// Assert.AreEqual(timestamp, packet.Timestamp()); | ||
// } | ||
// } | ||
//} | ||
#endregion | ||
|
||
#region #isDisposed | ||
[Test] | ||
public void IsDisposed_ShouldReturnFalse_When_NotDisposedYet() | ||
{ | ||
using (var packet = new FloatVectorPacket()) | ||
{ | ||
Assert.False(packet.isDisposed); | ||
} | ||
} | ||
|
||
[Test] | ||
public void IsDisposed_ShouldReturnTrue_When_AlreadyDisposed() | ||
{ | ||
var packet = new FloatVectorPacket(); | ||
packet.Dispose(); | ||
|
||
Assert.True(packet.isDisposed); | ||
} | ||
#endregion | ||
|
||
#region #At | ||
[Test] | ||
public void At_ShouldReturnNewPacketWithTimestamp() | ||
{ | ||
using (var timestamp = new Timestamp(1)) | ||
{ | ||
var floatVector = new float[6] { 10, 11, 12, 13, 14, 15 }; | ||
var packet = new FloatVectorPacket(floatVector).At(timestamp); | ||
Assert.AreEqual(floatVector, packet.Get()); | ||
Assert.AreEqual(timestamp, packet.Timestamp()); | ||
|
||
using (var newTimestamp = new Timestamp(2)) | ||
{ | ||
var newPacket = packet.At(newTimestamp); | ||
Assert.AreEqual(floatVector, newPacket.Get()); | ||
Assert.AreEqual(newTimestamp, newPacket.Timestamp()); | ||
} | ||
|
||
Assert.AreEqual(timestamp, packet.Timestamp()); | ||
} | ||
} | ||
#endregion | ||
|
||
#region #Consume | ||
[Test] | ||
public void Consume_ShouldThrowNotSupportedException() | ||
{ | ||
using (var packet = new FloatVectorPacket()) | ||
{ | ||
#pragma warning disable IDE0058 | ||
Assert.Throws<NotSupportedException>(() => { packet.Consume(); }); | ||
#pragma warning restore IDE0058 | ||
} | ||
} | ||
#endregion | ||
|
||
// #region #DebugTypeName | ||
// [Test] | ||
// public void DebugTypeName_ShouldReturnFloat_When_ValueIsSet() | ||
// { | ||
// using (var packet = new FloatPacket(0.01f)) | ||
// { | ||
// Assert.AreEqual("float", packet.DebugTypeName()); | ||
// } | ||
// } | ||
// #endregion | ||
} | ||
} |
Oops, something went wrong.