diff --git a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc
index b7a5c5904cf72..9b91215eba91d 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cc
@@ -119,9 +119,20 @@ Status GatherBlockQuantized<T1, T2, Tind>::ComputeInternal(OpKernelContext* ctx
     zero_points_ptr = zero_points->Data<T1>();
   }
 
+  // For packed uint8_t with bits < 8, after_gather_dim has to be
+  // adjusted to match the unpacked output dims so that kernel
+  // indexing is correct.
+  int64_t after_gather_dim_unpacked = after_gather_dim;
+  if constexpr (std::is_same_v<T1, uint8_t>) {
+    uint32_t components = 8 / static_cast<uint32_t>(bits_);
+    if (components > 1) {
+      after_gather_dim_unpacked *= components;
+    }
+  }
+
   GatherBlockQuantizedParam param;
   param.stream = Stream(ctx);
-  param.after_gather_dim = after_gather_dim;
+  param.after_gather_dim = after_gather_dim_unpacked;
   param.gather_axis_dim = data_shape[gather_axis_];
   param.ind_dim = ind_dim;
   param.bits = bits_;
diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
index 3bf37ea193245..6fea7a43712c7 100644
--- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
+++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
@@ -573,6 +573,80 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints_8Bits) {
 }
 #endif
 
+template <typename T1, typename T2, typename Tind>
+void Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_4Bits() {
+  // This test case is specific to the shared 4-bit token_embedding/lm_head use case on CUDA.
+  std::vector<int> data = {-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7,
+                           0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1};
+  std::vector<int64_t> data_shape = {2, 16};
+  std::vector<int> indices = {1};
+  std::vector<int64_t> indices_shape = {1};
+  std::vector<float> scales = {2.0f, 1.0f};
+  std::vector<int64_t> scales_shape = {2, 1};
+  // Explicit zero points for each row
+  std::vector<int> zero_points = {-2, 1};
+
+  // With explicit zero points:
+  // Unpacked data (row 1): [0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1] ---add offset 8--->
+  // Packed (offset 8 added): [8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7]
+  // Gathered scale (row 1): 1.0f; gathered zero point (row 1): [1] ---add offset 8---> [9]
+  // Expected (the CUDA kernel does not subtract the zero point): [8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7]
+  std::vector<float> output = {8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};
+  std::vector<int64_t> output_shape = {1, 16};
+
+  constexpr int64_t gather_axis = 0;
+  constexpr int64_t quantize_axis = 1;  // Last axis (required for CUDA)
+  constexpr int64_t block_size = 16;
+  constexpr int64_t bits = 4;
+  RunUnpackedData<T1, T2, Tind>(data, data_shape, indices, indices_shape, scales, scales_shape, zero_points,
+                                gather_axis, quantize_axis, block_size, bits, output, output_shape, true);
+}
+
+template <typename T1, typename T2, typename Tind>
+void Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_8Bits() {
+  // This test case is specific to the shared 8-bit token_embedding/lm_head use case on CUDA.
+  std::vector<int> data = {-128, -127, -126, -125, -124, -123, -122, -121, -120, -119, -118, -117, -116, -115, -114, -113,
+                           0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+  std::vector<int64_t> data_shape = {2, 16};
+  std::vector<int> indices = {1};
+  std::vector<int64_t> indices_shape = {1};
+  std::vector<float> scales = {1.0f, 2.0f};
+  std::vector<int64_t> scales_shape = {2, 1};
+  // Explicit zero points
+  std::vector<int> zero_points = {10, -5};
+
+  // With explicit zero points:
+  // Unpacked data (row 1): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] ---add offset 128--->
+  // Packed (row 1): [128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143]
+  // Zero point (row 1): [-5] ---add offset 128---> [123]
+  // Dequantization: [(128-123)*2, (129-123)*2, ..., (143-123)*2] = [10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40]
+  std::vector<float> output = {10.f, 12.f, 14.f, 16.f, 18.f, 20.f, 22.f, 24.f, 26.f, 28.f, 30.f, 32.f, 34.f, 36.f, 38.f, 40.f};
+  std::vector<int64_t> output_shape = {1, 16};
+
+  constexpr int64_t gather_axis = 0;
+  constexpr int64_t quantize_axis = 1;
+  constexpr int64_t block_size = 16;
+  constexpr int64_t bits = 8;
+  RunUnpackedData<T1, T2, Tind>(data, data_shape, indices, indices_shape, scales, scales_shape, zero_points,
+                                gather_axis, quantize_axis, block_size, bits, output, output_shape, true);
+}
+
+#ifdef USE_CUDA
+TEST(GatherBlockQuantizedOpTest, GatherAxis0_QuantizedAxis1_Uint8_4Bits_WithZeroPoints) {
+  Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_4Bits<uint8_t, float, int32_t>();
+  Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_4Bits<uint8_t, float, int64_t>();
+  Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_4Bits<uint8_t, MLFloat16, int32_t>();
+  Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_4Bits<uint8_t, MLFloat16, int64_t>();
+}
+
+TEST(GatherBlockQuantizedOpTest, GatherAxis0_QuantizedAxis1_Uint8_8Bits_WithZeroPoints) {
+  Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_8Bits<uint8_t, float, int32_t>();
+  Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_8Bits<uint8_t, float, int64_t>();
+  Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_8Bits<uint8_t, MLFloat16, int32_t>();
+  Test_GatherAxis0_QuantizedAxis1_WithZeroPoints_8Bits<uint8_t, MLFloat16, int64_t>();
+}
+#endif
+
 template <typename T1, typename T2, typename Tind>
 void Test_GatherAxis1_WithZeroPoints() {
   std::vector<int> data = {-8, -7, -6, -5,
@@ -665,5 +739,129 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis2) {
 }
 #endif
 
+template <typename T1, typename T2, typename Tind>
+void Test_GatherAxis_WithZeroPoints_NoPadding() {
+  std::vector<int> data = {
+      -8, -7, -6, -5, -8, -7, -6, -5, -8, -7, -6, -5, -8, -7, -6, -5,
+      -4, -3, -2, -1, -4, -3, -2, -1, -4, -3, -2, -1, -4, -3, -2, -1,
+      0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,
+      4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7,
+      4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7,
+      -4, -3, -2, -1, -4, -3, -2, -1, -4, -3, -2, -1, -4, -3, -2, -1};
+
+  std::vector<int64_t> data_shape = {2, 3, 16};
+  std::vector<int> indices = {1};
+  std::vector<int64_t> indices_shape = {1};
+  std::vector<float> scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
+  std::vector<int64_t> scales_shape = {2, 3, 1};
+  std::vector<int> zero_points = {-1, 1, 0, 0, 1, -1};
+  std::vector<float> output = {
+      8, 10, 12, 14, 8, 10, 12, 14, 8, 10, 12, 14, 8, 10, 12, 14,
+      3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6,
+      -6, -4, -2, 0, -6, -4, -2, 0, -6, -4, -2, 0, -6, -4, -2, 0};
+  std::vector<int64_t> output_shape = {1, 3, 16};
+
+  constexpr int64_t gather_axis = 0;
+  constexpr int64_t quantize_axis = 2;
+  constexpr int64_t block_size = 16;
+  constexpr int64_t bits = 4;
+
+  RunUnpackedData<T1, T2, Tind>(data, data_shape, indices, indices_shape, scales, scales_shape, zero_points,
+                                gather_axis, quantize_axis, block_size, bits, output, output_shape, true);
+}
+
+#ifdef USE_CUDA
+TEST(GatherBlockQuantizedOpTest, GatherAxisWithZeroPointsNoPadding) {
+  Test_GatherAxis_WithZeroPoints_NoPadding<UInt4x2, float, int32_t>();
+  Test_GatherAxis_WithZeroPoints_NoPadding<UInt4x2, float, int64_t>();
+  Test_GatherAxis_WithZeroPoints_NoPadding<UInt4x2, MLFloat16, int32_t>();
+  Test_GatherAxis_WithZeroPoints_NoPadding<UInt4x2, MLFloat16, int64_t>();
+  Test_GatherAxis_WithZeroPoints_NoPadding<uint8_t, float, int32_t>();
+  Test_GatherAxis_WithZeroPoints_NoPadding<uint8_t, float, int64_t>();
+  Test_GatherAxis_WithZeroPoints_NoPadding<uint8_t, MLFloat16, int32_t>();
+  Test_GatherAxis_WithZeroPoints_NoPadding<uint8_t, MLFloat16, int64_t>();
+}
+#endif
+
+template <typename T1, typename T2, typename Tind>
+void Test_GatherAxis_NoPadding_4bit() {
+  std::vector<int> data = {
+      -8, -7, -6, -5, -8, -7, -6, -5, -8, -7, -6, -5, -8, -7, -6, -5,
+      -8, -7, -6, -5, -8, -7, -6, -5, -8, -7, -6, -5, -8, -7, -6, -5,
+      0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,
+      0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,
+      7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4,
+      7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4};
+
+  std::vector<int64_t> data_shape = {2, 3, 16};
+  std::vector<int> indices = {0};
+  std::vector<int64_t> indices_shape = {1};
+  std::vector<float> scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
+  std::vector<int64_t> scales_shape = {2, 3, 1};
+  std::vector<int> zero_points = {};
+  std::vector<float> output = {
+      0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,
+      0, 2, 4, 6, 0, 2, 4, 6, 0, 2, 4, 6, 0, 2, 4, 6,
+      8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11};
+  std::vector<int64_t> output_shape = {1, 3, 16};
+
+  constexpr int64_t gather_axis = 0;
+  constexpr int64_t quantize_axis = 2;
+  constexpr int64_t block_size = 16;
+  constexpr int64_t bits = 4;
+
+  RunUnpackedData<T1, T2, Tind>(data, data_shape, indices, indices_shape, scales, scales_shape, zero_points,
+                                gather_axis, quantize_axis, block_size, bits, output, output_shape, true);
+}
+
+#ifdef USE_CUDA
+TEST(GatherBlockQuantizedOpTest, GatherAxisNoPaddingUInt8_4Bits) {
+  Test_GatherAxis_NoPadding_4bit<uint8_t, float, int32_t>();
+  Test_GatherAxis_NoPadding_4bit<uint8_t, float, int64_t>();
+  Test_GatherAxis_NoPadding_4bit<uint8_t, MLFloat16, int32_t>();
+  Test_GatherAxis_NoPadding_4bit<uint8_t, MLFloat16, int64_t>();
+}
+#endif
+
+template <typename T1, typename T2, typename Tind>
+void Test_GatherAxis_NoPadding_8bit() {
+  std::vector<int> data = {
+      127, 126, 125, 124, 123, 122, 121, 120, 119, 118, 117, 116, 115, 114, 113, 112,
+      127, 126, 125, 124, 123, 122, 121, 120, 119, 118, 117, 116, 115, 114, 113, 112,
+      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+      127, 126, 125, 124, 123, 122, 121, 120, 119, 118, 117, 116, 115, 114, 113, 112,
+      127, 126, 125, 124, 123, 122, 121, 120, 119, 118, 117, 116, 115, 114, 113, 112};
+
+  std::vector<int64_t> data_shape = {2, 3, 16};
+  std::vector<int> indices = {0};
+  std::vector<int64_t> indices_shape = {1};
+  std::vector<float> scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
+  std::vector<int64_t> scales_shape = {2, 3, 1};
+  std::vector<int> zero_points = {};
+  std::vector<float> output = {
+      255, 254, 253, 252, 251, 250, 249, 248, 247, 246, 245, 244, 243, 242, 241, 240,
+      510, 508, 506, 504, 502, 500, 498, 496, 494, 492, 490, 488, 486, 484, 482, 480,
+      128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143};
+  std::vector<int64_t> output_shape = {1, 3, 16};
+
+  constexpr int64_t gather_axis = 0;
+  constexpr int64_t quantize_axis = 2;
+  constexpr int64_t block_size = 16;
+  constexpr int64_t bits = 8;
+
+  RunUnpackedData<T1, T2, Tind>(data, data_shape, indices, indices_shape, scales, scales_shape, zero_points,
+                                gather_axis, quantize_axis, block_size, bits, output, output_shape, true);
+}
+
+#ifdef USE_CUDA
+TEST(GatherBlockQuantizedOpTest, GatherAxisNoPaddingUInt8) {
+  Test_GatherAxis_NoPadding_8bit<uint8_t, float, int32_t>();
+  Test_GatherAxis_NoPadding_8bit<uint8_t, float, int64_t>();
+  Test_GatherAxis_NoPadding_8bit<uint8_t, MLFloat16, int32_t>();
+  Test_GatherAxis_NoPadding_8bit<uint8_t, MLFloat16, int64_t>();
+}
+#endif
+
 }  // namespace test
 }  // namespace onnxruntime
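
For reference, a minimal standalone sketch of the packed-index arithmetic that the after_gather_dim adjustment above relies on. It is not part of the patch and not ONNX Runtime code: it assumes low-nibble-first packing of two 4-bit values per uint8_t byte, and the names (bits, components, after_gather_dim) merely mirror the kernel parameters for illustration.

// Illustrative sketch only: maps unpacked output element indices onto 4-bit
// values packed two-per-uint8_t, which is the indexing the adjusted
// after_gather_dim keeps consistent with the unpacked output shape.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  constexpr uint32_t bits = 4;               // packed element width
  constexpr uint32_t components = 8 / bits;  // values per uint8_t byte (2)

  // One gathered row of 16 quantized values 0..15, packed low nibble first
  // (the packing order is an assumption made for this illustration).
  std::vector<uint8_t> packed(16 / components, 0);
  for (uint32_t i = 0; i < 16; ++i) {
    packed[i / components] |= static_cast<uint8_t>((i & 0xF) << (bits * (i % components)));
  }

  // after_gather_dim as seen for packed storage (bytes) vs. the unpacked
  // value needed so that output indexing matches the unpacked output shape.
  const int64_t after_gather_dim = static_cast<int64_t>(packed.size());
  const int64_t after_gather_dim_unpacked = after_gather_dim * components;

  // Walk the output in unpacked element order and recover each quantized value.
  for (int64_t out_idx = 0; out_idx < after_gather_dim_unpacked; ++out_idx) {
    const int64_t byte_idx = out_idx / components;
    const uint32_t shift = bits * static_cast<uint32_t>(out_idx % components);
    const unsigned q = (packed[byte_idx] >> shift) & 0xF;
    std::printf("out[%lld] = %u\n", static_cast<long long>(out_idx), q);
  }
  return 0;
}

Scaling after_gather_dim by components keeps out_idx / components and out_idx % components aligned with the packed byte and the value within it once the output is walked in unpacked element order, which is the indexing mismatch the kernel-side change addresses.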