diff --git a/onnxruntime/core/providers/cpu/ml/svmclassifier.h b/onnxruntime/core/providers/cpu/ml/svmclassifier.h index b4927ee35e22f..e392d0915db68 100644 --- a/onnxruntime/core/providers/cpu/ml/svmclassifier.h +++ b/onnxruntime/core/providers/cpu/ml/svmclassifier.h @@ -22,6 +22,7 @@ class SVMCommon { ORT_THROW_IF_ERROR(info.GetAttrs("kernel_params", kernel_params)); if (!kernel_params.empty()) { + ORT_ENFORCE(kernel_params.size() == 3, "kernel_params must be empty or have 3 values not ", kernel_params.size(), "."); gamma_ = kernel_params[0]; coef0_ = kernel_params[1]; degree_ = kernel_params[2]; diff --git a/onnxruntime/test/providers/cpu/ml/svmclassifier_test.cc b/onnxruntime/test/providers/cpu/ml/svmclassifier_test.cc index 5240c909d2878..2bb504b73ed51 100644 --- a/onnxruntime/test/providers/cpu/ml/svmclassifier_test.cc +++ b/onnxruntime/test/providers/cpu/ml/svmclassifier_test.cc @@ -348,5 +348,47 @@ TEST(MLOpTest, SVMClassifierUndersizedProba) { test.Run(OpTester::ExpectResult::kExpectFailure, "prob_a attribute size"); } +TEST(MLOpTest, SVMClassifierDifferentSizeKernelParameters) { + OpTester test("SVMClassifier", 1, onnxruntime::kMLDomain); + + std::vector coefficients = {0.766398549079895f, 0.0871576070785522f, 0.110420741140842f, + -0.963976919651031f}; + std::vector support_vectors = {4.80000019073486f, 3.40000009536743f, 1.89999997615814f, + 5.f, 3.f, 1.60000002384186f, + 4.5f, 2.29999995231628f, 1.29999995231628f, + 5.09999990463257f, 2.5f, 3.f}; + std::vector rho = {2.23510527610779f}; + std::vector kernel_params = {0.122462183237076f, 0.f, 3.f, 0.5f}; // incorrect size + std::vector classes = {0, 1}; + std::vector vectors_per_class = {3, 1}; + + std::vector X = {5.1f, 3.5f, 1.4f, + 4.9f, 3.f, 1.4f, + 4.7f, 3.2f, 1.3f, + 4.6f, 3.1f, 1.5f, + 5.f, 3.6f, 1.4f}; + std::vector scores_predictions = {-1.5556798f, 1.5556798f, + -1.2610321f, 1.2610321f, + -1.5795376f, 1.5795376f, + -1.3083477f, 1.3083477f, + -1.6572928f, 1.6572928f}; + + std::vector class_predictions = {0, 0, 0, 0, 0}; + + test.AddAttribute("kernel_type", std::string("LINEAR")); + test.AddAttribute("coefficients", coefficients); + test.AddAttribute("support_vectors", support_vectors); + test.AddAttribute("vectors_per_class", vectors_per_class); + test.AddAttribute("rho", rho); + test.AddAttribute("kernel_params", kernel_params); + test.AddAttribute("classlabels_ints", classes); + + test.AddInput("X", {5, 3}, X); + test.AddOutput("Y", {5}, class_predictions); + test.AddOutput("Z", {5, 2}, scores_predictions); + + test.Run(OpTester::ExpectResult::kExpectFailure, "kernel_params must be empty or have 3 values"); +} + } // namespace test } // namespace onnxruntime