Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/cpu/ml/linearclassifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ LinearClassifier::LinearClassifier(const OpKernelInfo& info)

using_strings_ = !classlabels_strings_.empty();
class_count_ = static_cast<ptrdiff_t>(intercepts_.size());

ORT_ENFORCE(class_count_ > 0, "LinearClassifier: intercepts must not be empty.");
ORT_ENFORCE(coefficients_.size() % static_cast<size_t>(class_count_) == 0,
"LinearClassifier: coefficients size (", coefficients_.size(),
") must be a multiple of the number of classes (", class_count_, ").");
Comment thread
vraspar marked this conversation as resolved.

SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions());
}

Expand Down Expand Up @@ -146,6 +152,14 @@ Status LinearClassifier::Compute(OpKernelContext* ctx) const {
input_shape[0])
: narrow<ptrdiff_t>(input_shape[1]);

// Validate coefficients are large enough to prevent OOB read in GEMM.
const size_t expected_coefficients_size = SafeInt<size_t>(class_count_) * SafeInt<size_t>(num_features);
if (coefficients_.size() < expected_coefficients_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"LinearClassifier: coefficients length (", coefficients_.size(),
") is less than classes (", class_count_, ") * features (", num_features, ")");
}

Tensor* Y = ctx->Output(0, {num_batches});

int64_t output_classes = class_count_;
Expand Down
42 changes: 42 additions & 0 deletions onnxruntime/test/providers/cpu/ml/linearclassifer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,47 @@ TEST(MLOpTest, LinearClassifierMulticlassInt32Input) {
TEST(MLOpTest, LinearClassifierMulticlassDoubleInput) {
LinearClassifierMulticlass<double>();
}

// Regression test: coefficients size doesn't match class_count * num_features.
TEST(MLOpTest, LinearClassifierInvalidCoefficientsSizeFails) {
OpTester test("LinearClassifier", 1, onnxruntime::kMLDomain);

// 3 intercepts => class_count = 3, input has 2 features => expects 6 coefficients.
std::vector<float> coefficients = {-0.22562418f, 0.34188559f, 0.68346153f};
std::vector<int64_t> classes = {1, 2, 3};
std::vector<float> intercepts = {-3.91601811f, 0.42575697f, 0.13731251f};

test.AddAttribute("coefficients", coefficients);
test.AddAttribute("intercepts", intercepts);
test.AddAttribute("classlabels_ints", classes);

test.AddInput<float>("X", {1, 2}, {1.f, 0.f});
test.AddOutput<int64_t>("Y", {1}, {0LL});
test.AddOutput<float>("Z", {1, 3}, {0.f, 0.f, 0.f});

test.Run(OpTester::ExpectResult::kExpectFailure,
"LinearClassifier: coefficients length (3) is less than classes (3) * features (2)");
}

// Regression test: coefficients not divisible by class_count.
TEST(MLOpTest, LinearClassifierCoefficientsSizeNotDivisibleByClassCountFails) {
OpTester test("LinearClassifier", 1, onnxruntime::kMLDomain);

// 3 intercepts => class_count = 3, but 5 coefficients is not divisible by 3.
std::vector<float> coefficients = {1.f, 2.f, 3.f, 4.f, 5.f};
std::vector<int64_t> classes = {1, 2, 3};
std::vector<float> intercepts = {0.1f, 0.2f, 0.3f};

test.AddAttribute("coefficients", coefficients);
test.AddAttribute("intercepts", intercepts);
test.AddAttribute("classlabels_ints", classes);

test.AddInput<float>("X", {1, 2}, {1.f, 0.f});
test.AddOutput<int64_t>("Y", {1}, {0LL});
test.AddOutput<float>("Z", {1, 3}, {0.f, 0.f, 0.f});

test.Run(OpTester::ExpectResult::kExpectFailure,
"coefficients size (5) must be a multiple of the number of classes (3)");
}
} // namespace test
} // namespace onnxruntime
Loading