13 changes: 13 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h
@@ -67,6 +67,19 @@ Status CheckInputs(const T* /*activation*/,
  // group_index shall be 1D with K elements, or K padded to a multiple of block_size
  ASSERT_TENSOR_SHAPE_2(group_index, make_shape(k), make_shape(k_blocks * block_size));

  // Validate that group_index values are within the valid range [0, k_blocks)
  if (group_index != nullptr && group_index->Location().device.Type() == OrtDevice::CPU) {
    auto g_idx_data = group_index->template Data<int32_t>();
    auto g_idx_size = static_cast<size_t>(group_index->Shape().Size());
    for (size_t i = 0; i < g_idx_size; ++i) {
      if (g_idx_data[i] < 0 || g_idx_data[i] >= k_blocks) {
        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                               "group_index value at index ", i, " is ", g_idx_data[i],
                               ", which is out of valid range [0, ", k_blocks, ")");
      }
    }
  }

  ASSERT_TENSOR_SHAPE(bias, make_shape(n));

  return Status::OK();
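For context, the new check is a linear scan over the group_index tensor, and it only runs when the tensor is CPU-resident; presumably device-resident data cannot be inspected cheaply at validation time, which is why the CUDA hunk below adds its own kernel-side guard. A minimal standalone sketch of the same bounds-check pattern follows; the ValidateGroupIndex helper and the values in main are illustrative, not part of the PR.

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical standalone version of the check: every entry must name a
// real block of scales/zero points, i.e. fall in [0, k_blocks).
bool ValidateGroupIndex(const std::vector<int32_t>& g_idx, int64_t k_blocks) {
  for (size_t i = 0; i < g_idx.size(); ++i) {
    if (g_idx[i] < 0 || g_idx[i] >= k_blocks) {
      std::fprintf(stderr, "group_index value at index %zu is %d, out of valid range [0, %lld)\n",
                   i, g_idx[i], static_cast<long long>(k_blocks));
      return false;
    }
  }
  return true;
}

int main() {
  const int64_t K = 32, block_size = 16;
  const int64_t k_blocks = (K + block_size - 1) / block_size;  // 2
  std::vector<int32_t> good(static_cast<size_t>(K));
  for (int64_t i = 0; i < K; ++i) {
    good[static_cast<size_t>(i)] = static_cast<int32_t>(i / block_size);  // sixteen 0s, sixteen 1s
  }
  std::vector<int32_t> bad(static_cast<size_t>(K), 99999);  // every entry out of range
  std::printf("good: %d bad: %d\n", ValidateGroupIndex(good, k_blocks),
              ValidateGroupIndex(bad, k_blocks));  // good: 1 bad: 0
}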
@@ -112,6 +112,8 @@ __global__ void Dequantize4BitsKernelReOrder(
  const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx.x * element_per_thread) & (block_size - 1));
  for (int i = 0; i < element_per_thread; i++) {
    int32_t rid = reorder_idx_with_off[i];
    CUDA_KERNEL_ASSERT(rid >= 0 && rid < groups_per_K);
    rid = max(0, min(rid, groups_per_K - 1));  // Clamp so release builds, where the assert compiles out, cannot read out of bounds
    T scale = *(scale_data + n_idx * scales_shape_x + rid);
    uint8_t zp = 8;  // Default zero point is 1 << (bits - 1)
    if (zero_points) {
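The assert plus clamp pair is defense in depth: CUDA_KERNEL_ASSERT aborts the kernel in debug builds, while the clamp keeps release builds from reading past the scale table. A host-side C++ sketch of the same pattern, with illustrative names (LookupScale is not from the PR):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch: index a per-group scale table the way the kernel does.
// In debug builds the assert catches a bad rid immediately; in release
// builds (NDEBUG) the assert compiles away, so the clamp is the only
// thing keeping the read inside `scales`.
float LookupScale(const std::vector<float>& scales, int32_t rid) {
  const int32_t groups_per_K = static_cast<int32_t>(scales.size());
  assert(rid >= 0 && rid < groups_per_K);
  rid = std::max(0, std::min(rid, groups_per_K - 1));  // clamp for release safety
  return scales[static_cast<size_t>(rid)];
}

int main() {
  std::vector<float> scales = {0.5f, 0.25f};
  // In a release build this reads scales[1] (wrong but in-bounds) rather
  // than invoking undefined behavior; in a debug build the assert fires.
  return LookupScale(scales, 99999) > 0.0f ? 0 : 1;
}

With the clamp in place, a corrupt index degrades to a wrong-but-valid scale lookup instead of an out-of-bounds read.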
89 changes: 89 additions & 0 deletions onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -880,6 +880,95 @@ TEST(MatMulNBits, Basic_M10_N128_K512) {
}
#endif

// Test that out-of-range g_idx values are rejected with INVALID_ARGUMENT.
// The CUDA EP is excluded from these tests, so there is no risk of hitting CUDA_KERNEL_ASSERT.
TEST(MatMulNBits, InvalidGIdx_OutOfRange) {
  constexpr int64_t M = 2, N = 4, K = 32, block_size = 16;
  constexpr int64_t k_blocks = (K + block_size - 1) / block_size;  // 2
  constexpr int64_t blob_size = block_size * QBits / 8;            // 8

  OpTester test("MatMulNBits", 1, kMSDomain);
  test.AddAttribute<int64_t>("K", K);
  test.AddAttribute<int64_t>("N", N);
  test.AddAttribute<int64_t>("block_size", block_size);
  test.AddAttribute<int64_t>("bits", QBits);
  test.AddAttribute<int64_t>("accuracy_level", int64_t{0});

  // A: [M, K]
  std::vector<float> a_data(M * K, 1.0f);
  test.AddInput<float>("A", {M, K}, a_data, false);

  // B: [N, k_blocks, blob_size]
  std::vector<uint8_t> b_data(N * k_blocks * blob_size, 0);
  test.AddInput<uint8_t>("B", {N, k_blocks, blob_size}, b_data, true);

  // scales: [N, k_blocks]
  std::vector<float> scales(N * k_blocks, 1.0f);
  test.AddInput<float>("scales", {N, k_blocks}, scales, true);

  // zero_points: optional (skip)
  test.AddOptionalInputEdge<uint8_t>();

  // g_idx with out-of-range values (valid range is [0, k_blocks) = [0, 2))
  std::vector<int32_t> g_idx(K);
  for (int64_t i = 0; i < K; i++) {
    g_idx[i] = 99999;  // way out of range
  }
  test.AddInput<int32_t>("g_idx", {K}, g_idx, true);

  // bias: optional (skip)
  test.AddOptionalInputEdge<float>();

  // Output placeholder (won't actually be compared since we expect failure)
  std::vector<float> y_data(M * N, 0.0f);
  test.AddOutput<float>("Y", {M, N}, y_data);

  test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value",
           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider, kWebGpuExecutionProvider,
            kOpenVINOExecutionProvider});
}

// Test that negative g_idx values are rejected.
TEST(MatMulNBits, InvalidGIdx_Negative) {
  constexpr int64_t M = 2, N = 4, K = 32, block_size = 16;
  constexpr int64_t k_blocks = (K + block_size - 1) / block_size;
  constexpr int64_t blob_size = block_size * QBits / 8;

  OpTester test("MatMulNBits", 1, kMSDomain);
  test.AddAttribute<int64_t>("K", K);
  test.AddAttribute<int64_t>("N", N);
  test.AddAttribute<int64_t>("block_size", block_size);
  test.AddAttribute<int64_t>("bits", QBits);
  test.AddAttribute<int64_t>("accuracy_level", int64_t{0});

  std::vector<float> a_data(M * K, 1.0f);
  test.AddInput<float>("A", {M, K}, a_data, false);

  std::vector<uint8_t> b_data(N * k_blocks * blob_size, 0);
  test.AddInput<uint8_t>("B", {N, k_blocks, blob_size}, b_data, true);

  std::vector<float> scales(N * k_blocks, 1.0f);
  test.AddInput<float>("scales", {N, k_blocks}, scales, true);

  test.AddOptionalInputEdge<uint8_t>();

  // g_idx with negative values
  std::vector<int32_t> g_idx(K);
  for (int64_t i = 0; i < K; i++) {
    g_idx[i] = -1;
  }
  test.AddInput<int32_t>("g_idx", {K}, g_idx, true);

  test.AddOptionalInputEdge<float>();

  std::vector<float> y_data(M * N, 0.0f);
  test.AddOutput<float>("Y", {M, N}, y_data);

  test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value",
           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider, kWebGpuExecutionProvider,
            kOpenVINOExecutionProvider});
}

} // namespace test
} // namespace onnxruntime
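For contrast with the two negative tests above, a well-formed g_idx maps each of the K rows to the block that holds its scale. A hedged sketch of constructing one; the helper name is illustrative and not part of the test file:

#include <cstdint>
#include <vector>

// Hypothetical helper: the canonical contiguous group index maps row i to
// block i / block_size, so every value is guaranteed to land in [0, k_blocks).
std::vector<int32_t> MakeContiguousGroupIndex(int64_t K, int64_t block_size) {
  std::vector<int32_t> g_idx(static_cast<size_t>(K));
  for (int64_t i = 0; i < K; ++i) {
    g_idx[static_cast<size_t>(i)] = static_cast<int32_t>(i / block_size);
  }
  return g_idx;  // K=32, block_size=16 -> sixteen 0s followed by sixteen 1s
}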
