Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public InferenceTest(ITestOutputHelper o)
[Fact(DisplayName = "TestSessionOptions")]
public void TestSessionOptions()
{
// get instance to setup logging
// get instance to setup logging
var ortEnvInstance = OrtEnv.Instance();

using (SessionOptions opt = new SessionOptions())
Expand Down Expand Up @@ -1938,7 +1938,7 @@ internal static Tuple<InferenceSession, float[], DenseTensor<float>, float[]> Op
var session = (deviceId.HasValue)
? new InferenceSession(model, option)
: new InferenceSession(model);
float[] inputData = TestDataLoader.LoadTensorFromEmbeddedResource("bench.in");
float[] inputData = TestDataLoader.LoadTensorFromEmbeddedResource("bench.in");
float[] expectedOutput = TestDataLoader.LoadTensorFromEmbeddedResource("bench.expected_out");
var inputMeta = session.InputMetadata;
var tensor = new DenseTensor<float>(inputData, inputMeta["data_0"].Dimensions);
Expand All @@ -1961,6 +1961,21 @@ public int GetHashCode(float x)
}
}

internal class DoubleComparer : IEqualityComparer<double>
{
private double atol = 1e-3;
private double rtol = 1.7e-2;

public bool Equals(double x, double y)
{
return Math.Abs(x - y) <= (atol + rtol * Math.Abs(y));
}
public int GetHashCode(double x)
{
return x.GetHashCode();
}
}

class ExactComparer<T> : IEqualityComparer<T>
{
public bool Equals(T x, T y)
Expand Down Expand Up @@ -2069,4 +2084,4 @@ public void Dispose()
}
#endregion
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,10 @@ private void TestPreTrainedModels(string opset, string modelName)
{
Assert.Equal(result.AsTensor<float>(), outputValue.AsTensor<float>(), new FloatComparer());
}
else if (outputMeta.ElementType == typeof(double))
{
Assert.Equal(result.AsTensor<double>(), outputValue.AsTensor<double>(), new DoubleComparer());
}
else if (outputMeta.ElementType == typeof(int))
{
Assert.Equal(result.AsTensor<int>(), outputValue.AsTensor<int>(), new ExactComparer<int>());
Expand Down Expand Up @@ -560,12 +564,12 @@ private void TestPreTrainedModels(string opset, string modelName)
}
else
{
Assert.True(false, "The TestPretrainedModels does not yet support output of type " + nameof(outputMeta.ElementType));
Assert.True(false, $"{nameof(TestPreTrainedModels)} does not yet support output of type {outputMeta.ElementType}");
}
}
else
{
Assert.True(false, "TestPretrainedModel cannot handle non-tensor outputs yet");
Assert.True(false, $"{nameof(TestPreTrainedModels)} cannot handle non-tensor outputs yet");
}
}
}
Expand Down Expand Up @@ -808,4 +812,4 @@ static string GetTestModelsDir()
return modelsDir;
}
}
}
}