Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions onnxruntime/test/contrib_ops/tensor_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,12 @@ void MeanVarianceNormalizationAcrossChannels(bool across_channels, bool normaliz
test.AddAttribute("normalize_variance", normalize_variance ? one : zero);
test.AddInput<float>("input", {N, C, H, W}, X);
test.AddOutput<float>("output", {N, C, H, W}, result);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kTensorrtExecutionProvider}); // OpenVINO doesn't support MVN operator below opset 9. TensorRT doesn't support opset 8 of MVN operator.
// DML currently has known failures in this 4D MVN coverage.
Comment thread
hariharans29 marked this conversation as resolved.
// See https://github.com/microsoft/onnxruntime/issues/27933 and remove this exclusion once
// that issue is fixed. OpenVINO does not support MVN below opset 9. TensorRT does not
// support MVN opset 8.
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kDmlExecutionProvider, kOpenVINOExecutionProvider, kTensorrtExecutionProvider});
}

void MeanVarianceNormalizationPerChannel(bool across_channels, bool normalize_variance) {
Expand Down Expand Up @@ -188,7 +193,12 @@ void MeanVarianceNormalizationPerChannel(bool across_channels, bool normalize_va
test.AddAttribute("normalize_variance", normalize_variance ? one : zero);
test.AddInput<float>("input", {N, C, H, W}, X);
test.AddOutput<float>("output", {N, C, H, W}, result);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kTensorrtExecutionProvider}); // OpenVINO doesn't support MVN operator below opset 9. TensorRT doesn't support opset 8 of MVN operator.
// OpenVINO does not support MVN below opset 9. TensorRT does not support MVN opset 8.
// DML currently has known failures in this 4D MVN coverage.
// See https://github.com/microsoft/onnxruntime/issues/27933 and remove this exclusion once
// that issue is fixed.
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kDmlExecutionProvider, kOpenVINOExecutionProvider, kTensorrtExecutionProvider});
}

TEST(MVNContribOpTest, MeanVarianceNormalizationCPUTest_Version1_TO_8) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ TEST(MeanVarianceNormalizationTest, DefaultAxes) {
OpTester test("MeanVarianceNormalization", 9);
test.AddInput<float>("input", {N, C, H, W}, X);
test.AddOutput<float>("output", {N, C, H, W}, result);
test.Run();
// DML currently has known failures in this 4D default-axes MVN coverage.
// See https://github.com/microsoft/onnxruntime/issues/27933 and remove this exclusion once
// that issue is fixed.
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider});
}

static void TestMeanVarianceNormalizationOverAllAxes(const std::vector<int64_t>& shape) {
Expand All @@ -90,7 +93,14 @@ static void TestMeanVarianceNormalizationOverAllAxes(const std::vector<int64_t>&
test.AddInput<float>("input", shape, X);
test.AddOutput<float>("output", shape, Y);

test.Run();
if (shape.size() == 4) {
// Restrict the DML exclusion to the known failing 4D all-axes coverage.
Comment thread
hariharans29 marked this conversation as resolved.
// See https://github.com/microsoft/onnxruntime/issues/27933 and remove this exclusion once
// that issue is fixed.
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider});
} else {
test.Run();
}
}

TEST(MeanVarianceNormalizationTest, AllAxes) {
Expand Down Expand Up @@ -157,6 +167,7 @@ TEST(MeanVarianceNormalizationTest, AxesSubsets5D) {
test.AddOutput<float>("output", shape, Y.data(), Y.size());

if (DefaultDmlExecutionProvider().get() != nullptr) {
// 5D subset-axis coverage stays enabled for DML.
test.SetOutputTolerance(0.001f);
}

Expand Down
Loading