diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h
index 25bcb3932795b..072471d3b83a3 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h
@@ -67,6 +67,19 @@ Status CheckInputs(const T* /*activation*/,
   // Group_index shall be 1D of K, or K padded to multiple of block_size
   ASSERT_TENSOR_SHAPE_2(group_index, make_shape(k), make_shape(k_blocks * block_size));
 
+  // Validate group_index values are within valid range [0, k_blocks).
+  // Only possible when the tensor is CPU-resident; device tensors are validated in the kernel.
+  if (group_index != nullptr && group_index->Location().device.Type() == OrtDevice::CPU) {
+    auto g_idx_data = group_index->template Data<int32_t>();
+    auto g_idx_size = static_cast<size_t>(group_index->Shape().Size());
+    for (size_t i = 0; i < g_idx_size; ++i) {
+      if (g_idx_data[i] < 0 || g_idx_data[i] >= k_blocks) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "group_index value at index ", i, " is ", g_idx_data[i],
+                               ", which is out of valid range [0, ", k_blocks, ")");
+      }
+    }
+  }
+
   ASSERT_TENSOR_SHAPE(bias, make_shape(n));
 
   return Status::OK();
diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu
index cbcd4ed2f54a0..5c4501945cecf 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu
+++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu
@@ -112,6 +112,8 @@ __global__ void Dequantize4BitsKernelReOrder(
   const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx.x * element_per_thread) & (block_size - 1));
   for (int i = 0; i < element_per_thread; i++) {
     int32_t rid = reorder_idx_with_off[i];
+    CUDA_KERNEL_ASSERT(rid >= 0 && rid < groups_per_K);
+    rid = max(0, min(rid, groups_per_K - 1));  // Clamp for release safety
     T scale = *(scale_data + n_idx * scales_shape_x + rid);
     uint8_t zp = 8;  // Default zero point is 1 << (bits - 1)
     if (zero_points) {
diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
index fbbdc419118cd..b463aa3a6c363 100644
--- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -880,6 +880,95 @@ TEST(MatMulNBits, Basic_M10_N128_K512) {
 }
 
 #endif
 
+// Test that out-of-range g_idx values are rejected with INVALID_ARGUMENT.
+// CUDA EP is excluded from these tests, so no risk of hitting CUDA_KERNEL_ASSERT.
+TEST(MatMulNBits, InvalidGIdx_OutOfRange) {
+  constexpr int64_t M = 2, N = 4, K = 32, block_size = 16;
+  constexpr int64_t k_blocks = (K + block_size - 1) / block_size;  // 2
+  constexpr int64_t blob_size = block_size * QBits / 8;            // 8
+
+  OpTester test("MatMulNBits", 1, kMSDomain);
+  test.AddAttribute<int64_t>("K", K);
+  test.AddAttribute<int64_t>("N", N);
+  test.AddAttribute<int64_t>("block_size", block_size);
+  test.AddAttribute<int64_t>("bits", QBits);
+  test.AddAttribute<int64_t>("accuracy_level", int64_t{0});
+
+  // A: [M, K]
+  std::vector<float> a_data(M * K, 1.0f);
+  test.AddInput<float>("A", {M, K}, a_data, false);
+
+  // B: [N, k_blocks, blob_size]
+  std::vector<uint8_t> b_data(N * k_blocks * blob_size, 0);
+  test.AddInput<uint8_t>("B", {N, k_blocks, blob_size}, b_data, true);
+
+  // scales: [N, k_blocks]
+  std::vector<float> scales(N * k_blocks, 1.0f);
+  test.AddInput<float>("scales", {N, k_blocks}, scales, true);
+
+  // zero_points: optional (skip)
+  test.AddOptionalInputEdge<uint8_t>();
+
+  // g_idx with out-of-range values (valid range is [0, k_blocks) = [0, 2))
+  std::vector<int32_t> g_idx(K);
+  for (int64_t i = 0; i < K; i++) {
+    g_idx[i] = 99999;  // way out of range
+  }
+  test.AddInput<int32_t>("g_idx", {K}, g_idx, true);
+
+  // bias: optional (skip)
+  test.AddOptionalInputEdge<float>();
+
+  // Output placeholder (won't actually be compared since we expect failure)
+  std::vector<float> y_data(M * N, 0.0f);
+  test.AddOutput<float>("Y", {M, N}, y_data);
+
+  test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value",
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider, kWebGpuExecutionProvider,
+            kOpenVINOExecutionProvider});
+}
+
+// Test that negative g_idx values are rejected.
+TEST(MatMulNBits, InvalidGIdx_Negative) {
+  constexpr int64_t M = 2, N = 4, K = 32, block_size = 16;
+  constexpr int64_t k_blocks = (K + block_size - 1) / block_size;
+  constexpr int64_t blob_size = block_size * QBits / 8;
+
+  OpTester test("MatMulNBits", 1, kMSDomain);
+  test.AddAttribute<int64_t>("K", K);
+  test.AddAttribute<int64_t>("N", N);
+  test.AddAttribute<int64_t>("block_size", block_size);
+  test.AddAttribute<int64_t>("bits", QBits);
+  test.AddAttribute<int64_t>("accuracy_level", int64_t{0});
+
+  std::vector<float> a_data(M * K, 1.0f);
+  test.AddInput<float>("A", {M, K}, a_data, false);
+
+  std::vector<uint8_t> b_data(N * k_blocks * blob_size, 0);
+  test.AddInput<uint8_t>("B", {N, k_blocks, blob_size}, b_data, true);
+
+  std::vector<float> scales(N * k_blocks, 1.0f);
+  test.AddInput<float>("scales", {N, k_blocks}, scales, true);
+
+  test.AddOptionalInputEdge<uint8_t>();
+
+  // g_idx with negative values
+  std::vector<int32_t> g_idx(K);
+  for (int64_t i = 0; i < K; i++) {
+    g_idx[i] = -1;
+  }
+  test.AddInput<int32_t>("g_idx", {K}, g_idx, true);
+
+  test.AddOptionalInputEdge<float>();
+
+  std::vector<float> y_data(M * N, 0.0f);
+  test.AddOutput<float>("Y", {M, N}, y_data);
+
+  test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value",
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider, kWebGpuExecutionProvider,
+            kOpenVINOExecutionProvider});
+}
+
 }  // namespace test
 }  // namespace onnxruntime