From 2ac57eb7a0e51ffb678d6b2cc13a2c75f4be0cc6 Mon Sep 17 00:00:00 2001
From: Alex Marin
Date: Wed, 4 Jun 2025 16:01:16 -0700
Subject: [PATCH 01/88] implement TensorCaster specializations for int4/uint4

---
 .../core/providers/cpu/tensor/cast_op.cc      | 52 +++++++++++++++++++
 1 file changed, 52 insertions(+)

diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
index d1c280d9886f4..862fd9656d067 100644
--- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc
+++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
@@ -262,6 +262,58 @@ struct TensorCaster {
   }
 };
 
+// tensor Int4x2 -> float
+template <>
+struct TensorCaster<Int4x2, float> {
+  void Cast(const OpKernelContext& ctx, const TensorShape& shape, const Tensor& in, Tensor& out) const {
+    const auto* in_data = in.Data<Int4x2>();
+    auto* out_data = out.MutableData<float>();
+
+    // Confirm we can unpack the int4
+    const size_t shape_size = narrow<size_t>(shape.Size());
+    const size_t in_shape_size = narrow<size_t>(in.Shape().Size());
+    ORT_ENFORCE(in_shape_size * 2 == shape_size,
+                "The Int4x2 tensor size is invalid for casting to float.");
+
+    for (size_t i = 0; i < in_shape_size; ++i) {
+      const uint8_t packed = static_cast<uint8_t>(in_data[i].bits_);
+
+      // Extract signed high and low nibble
+      int8_t high_nibble = static_cast<int8_t>(packed) >> 4;
+      int8_t low_nibble = static_cast<int8_t>(packed << 4) >> 4;
+
+      out_data[2 * i] = static_cast<float>(high_nibble);
+      out_data[2 * i + 1] = static_cast<float>(low_nibble);
+    }
+  }
+};
+
+// tensor UInt4x2 -> float
+template <>
+struct TensorCaster<UInt4x2, float> {
+  void Cast(const OpKernelContext& ctx, const TensorShape& shape, const Tensor& in, Tensor& out) const {
+    const auto* in_data = in.Data<UInt4x2>();
+    auto* out_data = out.MutableData<float>();
+
+    // Confirm we can unpack the uint4
+    const size_t shape_size = narrow<size_t>(shape.Size());
+    const size_t in_shape_size = narrow<size_t>(in.Shape().Size());
+    ORT_ENFORCE(in_shape_size * 2 == shape_size,
+                "The UInt4x2 tensor size is invalid for casting to float.");
+
+    for (size_t i = 0; i < in_shape_size; ++i) {
+      const uint8_t packed = static_cast<uint8_t>(in_data[i].bits_);
+
+      // Extract unsigned high and low nibble
+      uint8_t high_nibble = (packed >> 4) & 0x0F;
+      uint8_t low_nibble = packed & 0x0F;
+
+      out_data[2 * i] = static_cast<float>(high_nibble);
+      out_data[2 * i + 1] = static_cast<float>(low_nibble);
+    }
+  }
+};
+
 #if defined(_M_AMD64) && !defined(_M_ARM64EC)
 
 // specializations to use optimized and Windows x64-specific

From e8dc543ad41b0c5fe40410e0751a7c76c6fab966 Mon Sep 17 00:00:00 2001
From: Alex Marin
Date: Thu, 5 Jun 2025 14:38:09 -0700
Subject: [PATCH 02/88] Create unit tests

---
 .../test/providers/cpu/tensor/cast_op_test.cc | 68 +++++++++++++++++--
 1 file changed, 64 insertions(+), 4 deletions(-)

diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
index 384adb5916cc1..e255a30b2df79 100644
--- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
@@ -1,6 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+#include #include #include "boost/mp11.hpp" @@ -54,15 +55,20 @@ template void TestCastOp(gsl::span input, gsl::span output, - const BaseTester::DimsVariant& dimensions, + const BaseTester::DimsVariant& input_dimensions, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& expected_failure_string = "", int opset = 13, - Saturate saturate = Saturate::None) { + Saturate saturate = Saturate::None, + std::optional output_dimensions_opt = std::nullopt, + std::optional input_element_count_override = std::nullopt, + std::optional output_element_count_override = std::nullop) { OpTester test("Cast", opset); + const BaseTester::DimsVariant& output_dimensions = output_dimensions_opt.value_or(input_dimensions); + test.AddAttribute("to", utils::ToTensorProtoElementType()); - test.AddInput("input", dimensions, input.data(), input.size()); - test.AddOutput("output", dimensions, output.data(), output.size()); + test.AddInput("input", input_dimensions, input.data(), input_element_count_override.value_or(input.size())); + test.AddOutput("output", output_dimensions, output.data(), output_element_count_override.value_or(output.size())); if (saturate != Saturate::None) { test.AddAttribute("saturate", saturate == Saturate::True ? 1 : 0); } @@ -207,6 +213,60 @@ TEST(CastOpTest, ToString) { TestCastOp(gsl::make_span(int_16_input), gsl::make_span(int_string_data), shape); } +TEST(CastOpTest, Int4ToFloat) { + // GIVEN + const std::vector input_shape{2, 2, 1}; + const std::vector int4_input = { + Int4x2(1, 2), // two 4-bit int elements: lower = 1, upper = 2 + Int4x2(-3, 4), // lower = -3, upper = 4 + Int4x2(2, -1), + Int4x2(-2, 2) + }; + // There will be twice as many unpacked elements + const std::vector output_shape{2, 2, 2}; + const std::vector expected_float_output = {1.0f, 2.0f, -3.0f, 4.0f, 2.0f, -1.0f, -2.0f, 2.0f}; + + // WHEN, THEN + TestCastOp( + gsl::make_span(int4_input), + gsl::make_span(expected_float_output), + input_shape, + OpTester::ExpectResult::kExpectSuccess, + "", + 13, + Saturate::None, + output_shape, + int4_input.size(), + expected_float_output.size()); +} + +TEST(CastOpTest, UInt4ToFloat) { + // GIVEN + const std::vector input_shape{2, 2, 1}; + const std::vector uint4_input = { + UInt4x2(1, 2), + UInt4x2(3, 4), + UInt4x2(5, 6), + UInt4x2(7, 8) + }; + // There will be twice as many unpacked elements + const std::vector output_shape{2, 2, 2}; + const std::vector expected_float_output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + + // WHEN, THEN + TestCastOp( + gsl::make_span(uint4_input), + gsl::make_span(expected_float_output), + input_shape, + OpTester::ExpectResult::kExpectSuccess, + "", + 13, + Saturate::None, + output_shape, + uint4_input.size(), + expected_float_output.size()); +} + #if !defined(DISABLE_FLOAT8_TYPES) template From a4dc230327386efaf6f3bea6a684266bdc9ed779 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Thu, 5 Jun 2025 16:49:02 -0700 Subject: [PATCH 03/88] update shape inside tests --- .../test/providers/cpu/tensor/cast_op_test.cc | 64 ++++++------------- 1 file changed, 18 insertions(+), 46 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index e255a30b2df79..0d1544a212307 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include #include #include "boost/mp11.hpp" @@ -55,20 +54,15 @@ template void TestCastOp(gsl::span input, gsl::span output, - const BaseTester::DimsVariant& input_dimensions, + const BaseTester::DimsVariant& dimensions, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& expected_failure_string = "", int opset = 13, - Saturate saturate = Saturate::None, - std::optional output_dimensions_opt = std::nullopt, - std::optional input_element_count_override = std::nullopt, - std::optional output_element_count_override = std::nullop) { + Saturate saturate = Saturate::None) { OpTester test("Cast", opset); - const BaseTester::DimsVariant& output_dimensions = output_dimensions_opt.value_or(input_dimensions); - test.AddAttribute("to", utils::ToTensorProtoElementType()); - test.AddInput("input", input_dimensions, input.data(), input_element_count_override.value_or(input.size())); - test.AddOutput("output", output_dimensions, output.data(), output_element_count_override.value_or(output.size())); + test.AddInput("input", dimensions, input.data(), input.size()); + test.AddOutput("output", dimensions, output.data(), input.size()); if (saturate != Saturate::None) { test.AddAttribute("saturate", saturate == Saturate::True ? 1 : 0); } @@ -215,56 +209,34 @@ TEST(CastOpTest, ToString) { TEST(CastOpTest, Int4ToFloat) { // GIVEN - const std::vector input_shape{2, 2, 1}; + const std::vector shape{2, 2, 2}; const std::vector int4_input = { - Int4x2(1, 2), // two 4-bit int elements: lower = 1, upper = 2 - Int4x2(-3, 4), // lower = -3, upper = 4 - Int4x2(2, -1), - Int4x2(-2, 2) + Int4x2(1, 2), // two 4-bit int elements: lower = 1, upper = 2 + Int4x2(-3, -4), + Int4x2(5, -6), + Int4x2(-8, 7) }; // There will be twice as many unpacked elements - const std::vector output_shape{2, 2, 2}; - const std::vector expected_float_output = {1.0f, 2.0f, -3.0f, 4.0f, 2.0f, -1.0f, -2.0f, 2.0f}; + const std::vector expected_float_output = {1.0f, 2.0f, -3.0f, -4.0f, 5.0f, -6.0f, -8.0f, 7.0f}; // WHEN, THEN - TestCastOp( - gsl::make_span(int4_input), - gsl::make_span(expected_float_output), - input_shape, - OpTester::ExpectResult::kExpectSuccess, - "", - 13, - Saturate::None, - output_shape, - int4_input.size(), - expected_float_output.size()); + TestCastOp(gsl::make_span(int4_input), gsl::make_span(expected_float_output), shape); } TEST(CastOpTest, UInt4ToFloat) { // GIVEN - const std::vector input_shape{2, 2, 1}; + const std::vector shape{2, 2, 2}; const std::vector uint4_input = { - UInt4x2(1, 2), - UInt4x2(3, 4), - UInt4x2(5, 6), - UInt4x2(7, 8) + UInt4x2(0, 1), + UInt4x2(2, 3), + UInt4x2(7, 8), + UInt4x2(14, 15) }; // There will be twice as many unpacked elements - const std::vector output_shape{2, 2, 2}; - const std::vector expected_float_output = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + const std::vector expected_float_output = {0.0f, 1.0f, 2.0f, 3.0f, 7.0f, 8.0f, 14.0f, 15.0f}; // WHEN, THEN - TestCastOp( - gsl::make_span(uint4_input), - gsl::make_span(expected_float_output), - input_shape, - OpTester::ExpectResult::kExpectSuccess, - "", - 13, - Saturate::None, - output_shape, - uint4_input.size(), - expected_float_output.size()); + TestCastOp(gsl::make_span(uint4_input), gsl::make_span(expected_float_output), shape); } #if !defined(DISABLE_FLOAT8_TYPES) From 6d0c03399fbbd99e6c078fdadd7a7130dbaaf035 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Fri, 6 Jun 2025 09:38:38 -0700 Subject: [PATCH 04/88] fix bug in tests --- 
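Note: each packed Int4x2/UInt4x2 element in these tests carries two logical 4-bit values, so the 4 packed inputs unpack to the 8 expected floats, and the expected-output span has to be sized from output.size() rather than input.size(). The standalone sketch below illustrates that size relationship and the sign-extending nibble unpack; PackedInt4 is a hypothetical stand-in for onnxruntime's Int4x2, so its name and member layout are assumptions.

// Standalone sketch (assumed types; mirrors the kernel's unpacking arithmetic).
#include <cstddef>
#include <cstdint>
#include <utility>

struct PackedInt4 {  // hypothetical stand-in: two signed 4-bit values per byte
  uint8_t bits;
};

// A packed tensor of N elements unpacks to 2 * N values, e.g. 4 packed
// elements correspond to the 8 floats of logical shape {2, 2, 2}.
inline std::size_t UnpackedCount(std::size_t packed_count) {
  return packed_count * 2;
}

// Sign-extend the low and high nibbles of one packed byte.
inline std::pair<int8_t, int8_t> UnpackSigned(PackedInt4 p) {
  const int8_t low = static_cast<int8_t>(static_cast<uint8_t>(p.bits << 4)) >> 4;
  const int8_t high = static_cast<int8_t>(p.bits) >> 4;
  return {low, high};
}
// Example: UnpackSigned({0x2F}) yields {-1, 2}, since nibble 0xF sign-extends to -1.
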
onnxruntime/test/providers/cpu/tensor/cast_op_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 0d1544a212307..4f4e44dfcdb0b 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -62,7 +62,7 @@ void TestCastOp(gsl::span input, OpTester test("Cast", opset); test.AddAttribute("to", utils::ToTensorProtoElementType()); test.AddInput("input", dimensions, input.data(), input.size()); - test.AddOutput("output", dimensions, output.data(), input.size()); + test.AddOutput("output", dimensions, output.data(), output.size()); if (saturate != Saturate::None) { test.AddAttribute("saturate", saturate == Saturate::True ? 1 : 0); } From 497920617216b4cc6df5b8315e4e909bd88d5d80 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Fri, 6 Jun 2025 12:48:27 -0700 Subject: [PATCH 05/88] update test names --- .../test/providers/cpu/tensor/cast_op_test.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 4f4e44dfcdb0b..c80828a0dd913 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -207,10 +207,10 @@ TEST(CastOpTest, ToString) { TestCastOp(gsl::make_span(int_16_input), gsl::make_span(int_string_data), shape); } -TEST(CastOpTest, Int4ToFloat) { +TEST(CastOpTest, Int4x2ToFloat) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector int4_input = { + const std::vector int4x2_input = { Int4x2(1, 2), // two 4-bit int elements: lower = 1, upper = 2 Int4x2(-3, -4), Int4x2(5, -6), @@ -220,13 +220,13 @@ TEST(CastOpTest, Int4ToFloat) { const std::vector expected_float_output = {1.0f, 2.0f, -3.0f, -4.0f, 5.0f, -6.0f, -8.0f, 7.0f}; // WHEN, THEN - TestCastOp(gsl::make_span(int4_input), gsl::make_span(expected_float_output), shape); + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float_output), shape); } -TEST(CastOpTest, UInt4ToFloat) { +TEST(CastOpTest, UInt4x2ToFloat) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector uint4_input = { + const std::vector uint4x2_input = { UInt4x2(0, 1), UInt4x2(2, 3), UInt4x2(7, 8), @@ -236,7 +236,7 @@ TEST(CastOpTest, UInt4ToFloat) { const std::vector expected_float_output = {0.0f, 1.0f, 2.0f, 3.0f, 7.0f, 8.0f, 14.0f, 15.0f}; // WHEN, THEN - TestCastOp(gsl::make_span(uint4_input), gsl::make_span(expected_float_output), shape); + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_float_output), shape); } #if !defined(DISABLE_FLOAT8_TYPES) From 6e34f5b6102261b78c8c7072ca9fa086881f2c1d Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Fri, 6 Jun 2025 12:49:18 -0700 Subject: [PATCH 06/88] Add CastTo/FromString and TensorCasterNoSat specializations --- .../core/providers/cpu/tensor/cast_op.cc | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 862fd9656d067..70f78b0e28962 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -124,6 +124,21 @@ CastToString(const SrcType& input, std::string& output) { CastToString(static_cast(input), output); } +inline void CastToString(Int4x2 value, std::string& out) { + // Int4x2 contains two 
4-bit signed integers + // Show both values as [first,second] + auto val0 = value.GetElem(0); // First 4-bit value + auto val1 = value.GetElem(1); // Second 4-bit value + out = "[" + std::to_string(static_cast(val0)) + "," + std::to_string(static_cast(val1)) + "]"; +} + +inline void CastToString(UInt4x2 value, std::string& out) { + // UInt4x2 contains two 4-bit unsigned integers + auto val0 = value.GetElem(0); // First 4-bit value + auto val1 = value.GetElem(1); // Second 4-bit value + out = "[" + std::to_string(static_cast(val0)) + "," + std::to_string(static_cast(val1)) + "]"; +} + template typename std::enable_if::value, void>::type CastFromString(const std::string& input, DstType& output) { @@ -148,6 +163,66 @@ CastFromString(const std::string& input, DstType& output) { output = gsl::narrow_cast(std::stoll(input)); } +inline void CastFromString(const std::string& in, Int4x2& out) { + // Parse string format: "[-3,7]" or "-3,7" or just "-3" (single value) + std::string trimmed = in; + + // Remove brackets if present + if (!trimmed.empty() && trimmed.front() == '[') { + trimmed = trimmed.substr(1); + } + if (!trimmed.empty() && trimmed.back() == ']') { + trimmed = trimmed.substr(0, trimmed.length() - 1); + } + + // Find comma separator + size_t comma_pos = trimmed.find(','); + int8_t val0 = 0, val1 = 0; + if (comma_pos != std::string::npos) { + // Two values: "val0,val1" + std::string val0_str = trimmed.substr(0, comma_pos); + std::string val1_str = trimmed.substr(comma_pos + 1); + + val0 = static_cast(std::clamp(std::stoi(val0_str), -8, 7)); + val1 = static_cast(std::clamp(std::stoi(val1_str), -8, 7)); + } else { + // Single value - use for both elements + val0 = val1 = static_cast(std::clamp(std::stoi(trimmed), -8, 7)); + } + + out = Int4x2(val0, val1); +} + +inline void CastFromString(const std::string& in, UInt4x2& out) { + // Parse string format: "[5,12]" or "5,12" or just "5" (single value) + std::string trimmed = in; + + // Remove brackets if present + if (!trimmed.empty() && trimmed.front() == '[') { + trimmed = trimmed.substr(1); + } + if (!trimmed.empty() && trimmed.back() == ']') { + trimmed = trimmed.substr(0, trimmed.length() - 1); + } + + // Find comma separator + size_t comma_pos = trimmed.find(','); + uint8_t val0 = 0, val1 = 0; + if (comma_pos != std::string::npos) { + // Two values: "val0,val1" + std::string val0_str = trimmed.substr(0, comma_pos); + std::string val1_str = trimmed.substr(comma_pos + 1); + + val0 = static_cast(std::clamp(std::stoi(val0_str), 0, 15)); + val1 = static_cast(std::clamp(std::stoi(val1_str), 0, 15)); + } else { + // Single value - use for both elements + val0 = val1 = static_cast(std::clamp(std::stoi(trimmed), 0, 15)); + } + + out = UInt4x2(val0, val1); +} + template #if !defined(DISABLE_FLOAT8_TYPES) typename std::enable_if::value || IsOrtFloat8Type::value, void>::type @@ -234,6 +309,21 @@ struct TensorCasterNoSat { } }; +// TensorCasterNoSat should never be instantiated for Int4x2/UInt4x2 +template +struct TensorCasterNoSat { + void Cast(const OpKernelContext&, const TensorShape&, const Tensor&, Tensor&) const { + ORT_THROW("Int4x2 should never use TensorCasterNoSat"); + } +}; + +template +struct TensorCasterNoSat { + void Cast(const OpKernelContext&, const TensorShape&, const Tensor&, Tensor&) const { + ORT_THROW("UInt4x2 should never use TensorCasterNoSat"); + } +}; + // tensor string -> float 8 template struct TensorCasterNoSat { From 61d50d8c1ce0b22a4297434e1c72faf6f50ec5a6 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Fri, 6 Jun 2025 
12:49:48 -0700 Subject: [PATCH 07/88] fix warning --- onnxruntime/core/providers/cpu/tensor/cast_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 70f78b0e28962..67936558a082b 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -355,7 +355,7 @@ struct TensorCaster { // tensor Int4x2 -> float template <> struct TensorCaster { - void Cast(const OpKernelContext& ctx, const TensorShape& shape, const Tensor& in, Tensor& out) const { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -381,7 +381,7 @@ struct TensorCaster { // tensor UInt4x2 -> float template <> struct TensorCaster { - void Cast(const OpKernelContext& ctx, const TensorShape& shape, const Tensor& in, Tensor& out) const { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); From 248741150e6d117d16c3677cf6c040a86f003bc7 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Fri, 6 Jun 2025 12:52:28 -0700 Subject: [PATCH 08/88] apply lintrunner --- onnxruntime/test/providers/cpu/tensor/cast_op_test.cc | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index c80828a0dd913..fbb508bc2d034 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -211,11 +211,10 @@ TEST(CastOpTest, Int4x2ToFloat) { // GIVEN const std::vector shape{2, 2, 2}; const std::vector int4x2_input = { - Int4x2(1, 2), // two 4-bit int elements: lower = 1, upper = 2 + Int4x2(1, 2), // two 4-bit int elements: lower = 1, upper = 2 Int4x2(-3, -4), Int4x2(5, -6), - Int4x2(-8, 7) - }; + Int4x2(-8, 7)}; // There will be twice as many unpacked elements const std::vector expected_float_output = {1.0f, 2.0f, -3.0f, -4.0f, 5.0f, -6.0f, -8.0f, 7.0f}; @@ -230,8 +229,7 @@ TEST(CastOpTest, UInt4x2ToFloat) { UInt4x2(0, 1), UInt4x2(2, 3), UInt4x2(7, 8), - UInt4x2(14, 15) - }; + UInt4x2(14, 15)}; // There will be twice as many unpacked elements const std::vector expected_float_output = {0.0f, 1.0f, 2.0f, 3.0f, 7.0f, 8.0f, 14.0f, 15.0f}; From e2f6f028f2ac9612df335c0131d96afe33773563 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Fri, 6 Jun 2025 12:55:31 -0700 Subject: [PATCH 09/88] output low nibble first --- onnxruntime/core/providers/cpu/tensor/cast_op.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 67936558a082b..48563abb3129e 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -372,8 +372,8 @@ struct TensorCaster { int8_t high_nibble = static_cast(packed) >> 4; int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = static_cast(high_nibble); - out_data[2 * i + 1] = static_cast(low_nibble); + out_data[2 * i] = static_cast(low_nibble); + out_data[2 * i + 1] = static_cast(high_nibble); } } }; @@ -398,8 +398,8 @@ struct TensorCaster { uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = 
static_cast(high_nibble); - out_data[2 * i + 1] = static_cast(low_nibble); + out_data[2 * i] = static_cast(low_nibble); + out_data[2 * i + 1] = static_cast(high_nibble); } } }; From 4ffef5f7ef393ec065fe3807f61534d924e2c348 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Fri, 6 Jun 2025 13:08:25 -0700 Subject: [PATCH 10/88] Update to use Rv10 instead of Rv9 (breaks build) --- onnxruntime/core/providers/cpu/tensor/cast_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 48563abb3129e..9453cd5ced02a 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -31,7 +31,7 @@ namespace op_kernel_type_control { // we're using one set of types for all opsets of Cast ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0, - element_type_lists::AllIRv9); + element_type_lists::AllIRv10); ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0, @@ -39,7 +39,7 @@ ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS( ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0, - element_type_lists::AllIRv9); + element_type_lists::AllIRv10); ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0, From e0079c0a70d0f6c6cc1a942e457d78c08ea19da0 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Fri, 6 Jun 2025 16:39:16 -0700 Subject: [PATCH 11/88] more than 1 partial specialization --- .../core/providers/cpu/tensor/cast_op.cc | 419 ++++++++++++++++-- 1 file changed, 370 insertions(+), 49 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 9453cd5ced02a..7ed374da44c5e 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -252,6 +252,213 @@ template <> struct EigenCastType { using type = Eigen::bfloat16; }; + +// Helper struct for converting from Int4x2/UInt4x2 elements to any destination type +namespace { +template +struct Int4ElementConverter { + static DstType Convert(int8_t val) { + // Default implementation for most numeric types + return static_cast(val); + } +}; + +template <> +struct Int4ElementConverter { + static MLFloat16 Convert(int8_t val) { + return MLFloat16(static_cast(val)); + } +}; + +template <> +struct Int4ElementConverter { + static BFloat16 Convert(int8_t val) { + return BFloat16(static_cast(val)); + } +}; + +template <> +struct Int4ElementConverter { + static std::string Convert(int8_t val) { + return std::to_string(static_cast(val)); + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template <> +struct Int4ElementConverter { + static Float8E4M3FN Convert(int8_t val) { + return Float8E4M3FN(static_cast(val), true); + } +}; + +template <> +struct Int4ElementConverter { + static Float8E4M3FNUZ Convert(int8_t val) { + return Float8E4M3FNUZ(static_cast(val), true); + } +}; + +template <> +struct Int4ElementConverter { + static Float8E5M2 Convert(int8_t val) { + return Float8E5M2(static_cast(val), true); + } +}; + +template <> +struct Int4ElementConverter { + static Float8E5M2FNUZ Convert(int8_t val) { + return Float8E5M2FNUZ(static_cast(val), true); + } +}; + +#endif + +// For unsigned int4 elements, we use the same converter but with uint8_t input +template +static 
DstType ConvertUInt4Element(uint8_t val) { + return Int4ElementConverter::Convert(static_cast(val)); +} + +// Helper struct for converting from any type to Int4/UInt4 elements +template +struct ToInt4ElementConverter { + // Default implementation for most numeric types + static int8_t ConvertToInt4(const SrcType& val) { + int8_t result = static_cast(val); + // Clamp to int4 range (-8 to 7) + return std::clamp(result, static_cast(-8), static_cast(7)); + } + + static uint8_t ConvertToUInt4(const SrcType& val) { + uint8_t result = static_cast(val); + // Clamp to uint4 range (0 to 15) + return std::min(result, static_cast(15)); + } +}; + +template <> +struct ToInt4ElementConverter { + static int8_t ConvertToInt4(const float& val) { + int8_t result = static_cast(std::roundf(val)); + return std::clamp(result, static_cast(-8), static_cast(7)); + } + + static uint8_t ConvertToUInt4(const float& val) { + uint8_t result = static_cast(std::max(0.0f, std::roundf(val))); + return std::min(result, static_cast(15)); + } +}; + +template <> +struct ToInt4ElementConverter { + static int8_t ConvertToInt4(const double& val) { + int8_t result = static_cast(std::round(val)); + return std::clamp(result, static_cast(-8), static_cast(7)); + } + + static uint8_t ConvertToUInt4(const double& val) { + uint8_t result = static_cast(std::max(0.0, std::round(val))); + return std::min(result, static_cast(15)); + } +}; + +template <> +struct ToInt4ElementConverter { + static int8_t ConvertToInt4(const BFloat16& val) { + return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); + } + + static uint8_t ConvertToUInt4(const BFloat16& val) { + return ToInt4ElementConverter::ConvertToUInt4(static_cast(val)); + } +}; + +template <> +struct ToInt4ElementConverter { + static int8_t ConvertToInt4(const MLFloat16& val) { + return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); + } + + static uint8_t ConvertToUInt4(const MLFloat16& val) { + return ToInt4ElementConverter::ConvertToUInt4(static_cast(val)); + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template <> +struct ToInt4ElementConverter { + static int8_t ConvertToInt4(const Float8E4M3FN& val) { + return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); + } + + static uint8_t ConvertToUInt4(const Float8E4M3FN& val) { + return ToInt4ElementConverter::ConvertToUInt4(static_cast(val)); + } +}; + +template <> +struct ToInt4ElementConverter { + static int8_t ConvertToInt4(const Float8E4M3FNUZ& val) { + return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); + } + + static uint8_t ConvertToUInt4(const Float8E4M3FNUZ& val) { + return ToInt4ElementConverter::ConvertToUInt4(static_cast(val)); + } +}; + +template <> +struct ToInt4ElementConverter { + static int8_t ConvertToInt4(const Float8E5M2& val) { + return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); + } + + static uint8_t ConvertToUInt4(const Float8E5M2& val) { + return ToInt4ElementConverter::ConvertToUInt4(static_cast(val)); + } +}; + +template <> +struct ToInt4ElementConverter { + static int8_t ConvertToInt4(const Float8E5M2FNUZ& val) { + return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); + } + + static uint8_t ConvertToUInt4(const Float8E5M2FNUZ& val) { + return ToInt4ElementConverter::ConvertToUInt4(static_cast(val)); + } +}; + +#endif + +template <> +struct ToInt4ElementConverter { + static int8_t ConvertToInt4(const std::string& val) { + int result; + try { + result = std::stoi(val); + } catch (...) 
{ + result = 0; + } + return std::clamp(result, -8, 7); + } + + static uint8_t ConvertToUInt4(const std::string& val) { + unsigned int result; + try { + result = std::stoul(val); + } catch (...) { + result = 0; + } + return std::min(result, 15u); + } +}; +} // anonymous namespace + // generic tensor X -> Y template struct TensorCaster { @@ -294,53 +501,6 @@ struct TensorCaster { } }; -#if !defined(DISABLE_FLOAT8_TYPES) - -// tensor X -> float 8 -template -struct TensorCasterNoSat { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const std::ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - for (std::ptrdiff_t i = 0; i < shape_size; ++i) { - out_data[i] = DstType(static_cast(in_data[i]), false); - } - } -}; - -// TensorCasterNoSat should never be instantiated for Int4x2/UInt4x2 -template -struct TensorCasterNoSat { - void Cast(const OpKernelContext&, const TensorShape&, const Tensor&, Tensor&) const { - ORT_THROW("Int4x2 should never use TensorCasterNoSat"); - } -}; - -template -struct TensorCasterNoSat { - void Cast(const OpKernelContext&, const TensorShape&, const Tensor&, Tensor&) const { - ORT_THROW("UInt4x2 should never use TensorCasterNoSat"); - } -}; - -// tensor string -> float 8 -template -struct TensorCasterNoSat { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const std::ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - float float_value; - for (std::ptrdiff_t i = 0; i < shape_size; ++i) { - CastFromString(in_data[i], float_value); - out_data[i] = DstType(float_value, false); - } - } -}; - -#endif - // tensor MLFloat16 -> float template <> struct TensorCaster { @@ -354,7 +514,7 @@ struct TensorCaster { // tensor Int4x2 -> float template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -380,7 +540,7 @@ struct TensorCaster { // tensor UInt4x2 -> float template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -404,6 +564,120 @@ struct TensorCaster { } }; +// Specialization for Int4x2 to any non-float type +template +struct TensorCaster::value>::type> { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + // Confirm we can unpack the int4 + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, + "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + + // Extract signed high and low nibble + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + + // Low nibble first, then high nibble + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +// Specialization for UInt4x2 to any non-float type +template +struct TensorCaster::value>::type> { + void 
Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + // Confirm we can unpack the uint4 + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, + "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + + // Extract unsigned high and low nibble + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + + // Low nibble first, then high nibble + out_data[2 * i] = ConvertUInt4Element(low_nibble); + out_data[2 * i + 1] = ConvertUInt4Element(high_nibble); + } + } +}; + +// Tensor any type to Int4x2 +template +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + // Confirm we can pack to int4x2 + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, + "The output Int4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); + + // Process pairs of elements, packing them into Int4x2 + size_t i = 0; + for (; i < in_shape_size - 1; i += 2) { + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + out_data[i / 2] = Int4x2(low_val, high_val); + } + + // Handle odd number of elements by padding with 0 + if (i < in_shape_size) { + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + out_data[i / 2] = Int4x2(low_val, 0); + } + } +}; + +// Tensor any type to UInt4x2 +template +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + // Confirm we can pack to uint4x2 + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, + "The output UInt4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); + + // Process pairs of elements, packing them into UInt4x2 + size_t i = 0; + for (; i < in_shape_size - 1; i += 2) { + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + out_data[i / 2] = UInt4x2(low_val, high_val); + } + + // Handle odd number of elements by padding with 0 + if (i < in_shape_size) { + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + out_data[i / 2] = UInt4x2(low_val, 0); + } + } +}; + #if defined(_M_AMD64) && !defined(_M_ARM64EC) // specializations to use optimized and Windows x64-specific @@ -441,6 +715,53 @@ struct TensorCaster { }; #endif +#if !defined(DISABLE_FLOAT8_TYPES) + +// tensor X -> float 8 +template +struct TensorCasterNoSat { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const std::ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + for (std::ptrdiff_t i = 0; i < shape_size; ++i) { + out_data[i] = DstType(static_cast(in_data[i]), false); + } + } +}; + +// 
TensorCasterNoSat should never be instantiated for Int4x2/UInt4x2 +template +struct TensorCasterNoSat { + void Cast(const OpKernelContext&, const TensorShape&, const Tensor&, Tensor&) const { + ORT_THROW("Int4x2 should never use TensorCasterNoSat"); + } +}; + +template +struct TensorCasterNoSat { + void Cast(const OpKernelContext&, const TensorShape&, const Tensor&, Tensor&) const { + ORT_THROW("UInt4x2 should never use TensorCasterNoSat"); + } +}; + +// tensor string -> float 8 +template +struct TensorCasterNoSat { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const std::ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + float float_value; + for (std::ptrdiff_t i = 0; i < shape_size; ++i) { + CastFromString(in_data[i], float_value); + out_data[i] = DstType(float_value, false); + } + } +}; + +#endif + class Cast final : public OpKernel { public: Cast(const OpKernelInfo& info) : OpKernel(info) { From 6157a9df1d31f12d98f5e02b4fce3d710e6c566b Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Fri, 6 Jun 2025 16:50:44 -0700 Subject: [PATCH 12/88] cannot convert OldType to NewType --- .../core/providers/cpu/tensor/cast_op.cc | 302 +++++++++++------- 1 file changed, 184 insertions(+), 118 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 7ed374da44c5e..de436ce4391fe 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -313,15 +313,8 @@ struct Int4ElementConverter { return Float8E5M2FNUZ(static_cast(val), true); } }; - #endif -// For unsigned int4 elements, we use the same converter but with uint8_t input -template -static DstType ConvertUInt4Element(uint8_t val) { - return Int4ElementConverter::Convert(static_cast(val)); -} - // Helper struct for converting from any type to Int4/UInt4 elements template struct ToInt4ElementConverter { @@ -459,6 +452,109 @@ struct ToInt4ElementConverter { }; } // anonymous namespace +#define DEFINE_INT4X2_TO_TYPE_CASTER(TYPE) \ + template <> \ + struct TensorCaster { \ + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { \ + const auto* in_data = in.Data(); \ + auto* out_data = out.MutableData(); \ + \ + const size_t shape_size = narrow(shape.Size()); \ + const size_t in_shape_size = narrow(in.Shape().Size()); \ + ORT_ENFORCE(in_shape_size * 2 == shape_size, \ + "The Int4x2 tensor size is invalid for casting."); \ + \ + for (size_t i = 0; i < in_shape_size; ++i) { \ + const uint8_t packed = static_cast(in_data[i].bits_); \ + \ + int8_t high_nibble = static_cast(packed) >> 4; \ + int8_t low_nibble = static_cast(packed << 4) >> 4; \ + \ + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); \ + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); \ + } \ + } \ + }; + +// Define a macro to generate full specializations for UInt4x2 to specific types +#define DEFINE_UINT4X2_TO_TYPE_CASTER(TYPE) \ + template <> \ + struct TensorCaster { \ + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { \ + const auto* in_data = in.Data(); \ + auto* out_data = out.MutableData(); \ + \ + const size_t shape_size = narrow(shape.Size()); \ + const size_t in_shape_size = narrow(in.Shape().Size()); \ + ORT_ENFORCE(in_shape_size * 2 == shape_size, \ + "The UInt4x2 tensor size is invalid for casting."); \ + \ 
+ for (size_t i = 0; i < in_shape_size; ++i) { \ + const uint8_t packed = static_cast(in_data[i].bits_); \ + \ + uint8_t high_nibble = (packed >> 4) & 0x0F; \ + uint8_t low_nibble = packed & 0x0F; \ + \ + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); \ + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); \ + } \ + } \ + }; + +// Define a macro to generate full specializations for type to Int4x2 +#define DEFINE_TYPE_TO_INT4X2_CASTER(TYPE) \ + template <> \ + struct TensorCaster { \ + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { \ + const auto* in_data = in.Data(); \ + auto* out_data = out.MutableData(); \ + \ + const size_t in_shape_size = narrow(shape.Size()); \ + const size_t out_shape_size = narrow(out.Shape().Size()); \ + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, \ + "The output Int4x2 tensor size is invalid for casting from ", typeid(TYPE).name()); \ + \ + size_t i = 0; \ + for (; i < in_shape_size - 1; i += 2) { \ + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); \ + int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); \ + out_data[i / 2] = Int4x2(low_val, high_val); \ + } \ + \ + if (i < in_shape_size) { \ + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); \ + out_data[i / 2] = Int4x2(low_val, 0); \ + } \ + } \ + }; + +// Define a macro to generate full specializations for type to UInt4x2 +#define DEFINE_TYPE_TO_UINT4X2_CASTER(TYPE) \ + template <> \ + struct TensorCaster { \ + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { \ + const auto* in_data = in.Data(); \ + auto* out_data = out.MutableData(); \ + \ + const size_t in_shape_size = narrow(shape.Size()); \ + const size_t out_shape_size = narrow(out.Shape().Size()); \ + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, \ + "The output UInt4x2 tensor size is invalid for casting from ", typeid(TYPE).name()); \ + \ + size_t i = 0; \ + for (; i < in_shape_size - 1; i += 2) { \ + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); \ + uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); \ + out_data[i / 2] = UInt4x2(low_val, high_val); \ + } \ + \ + if (i < in_shape_size) { \ + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); \ + out_data[i / 2] = UInt4x2(low_val, 0); \ + } \ + } \ + }; + // generic tensor X -> Y template struct TensorCaster { @@ -563,120 +659,90 @@ struct TensorCaster { } } }; +// Generate all the required specializations +// Int4x2/UInt4x2 to various types +DEFINE_INT4X2_TO_TYPE_CASTER(int8_t) +DEFINE_INT4X2_TO_TYPE_CASTER(uint8_t) +DEFINE_INT4X2_TO_TYPE_CASTER(int16_t) +DEFINE_INT4X2_TO_TYPE_CASTER(uint16_t) +DEFINE_INT4X2_TO_TYPE_CASTER(int32_t) +DEFINE_INT4X2_TO_TYPE_CASTER(uint32_t) +DEFINE_INT4X2_TO_TYPE_CASTER(int64_t) +DEFINE_INT4X2_TO_TYPE_CASTER(uint64_t) +DEFINE_INT4X2_TO_TYPE_CASTER(double) +DEFINE_INT4X2_TO_TYPE_CASTER(bool) +DEFINE_INT4X2_TO_TYPE_CASTER(MLFloat16) +DEFINE_INT4X2_TO_TYPE_CASTER(BFloat16) +DEFINE_INT4X2_TO_TYPE_CASTER(std::string) +#if !defined(DISABLE_FLOAT8_TYPES) +DEFINE_INT4X2_TO_TYPE_CASTER(Float8E4M3FN) +DEFINE_INT4X2_TO_TYPE_CASTER(Float8E4M3FNUZ) +DEFINE_INT4X2_TO_TYPE_CASTER(Float8E5M2) +DEFINE_INT4X2_TO_TYPE_CASTER(Float8E5M2FNUZ) +#endif -// Specialization for Int4x2 to any non-float type -template -struct TensorCaster::value>::type> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& 
out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - // Confirm we can unpack the int4 - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, - "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - - // Extract signed high and low nibble - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - - // Low nibble first, then high nibble - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -// Specialization for UInt4x2 to any non-float type -template -struct TensorCaster::value>::type> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - // Confirm we can unpack the uint4 - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, - "The UInt4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - - // Extract unsigned high and low nibble - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; - - // Low nibble first, then high nibble - out_data[2 * i] = ConvertUInt4Element(low_nibble); - out_data[2 * i + 1] = ConvertUInt4Element(high_nibble); - } - } -}; - -// Tensor any type to Int4x2 -template -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - // Confirm we can pack to int4x2 - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output Int4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); - - // Process pairs of elements, packing them into Int4x2 - size_t i = 0; - for (; i < in_shape_size - 1; i += 2) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); - out_data[i / 2] = Int4x2(low_val, high_val); - } - - // Handle odd number of elements by padding with 0 - if (i < in_shape_size) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - out_data[i / 2] = Int4x2(low_val, 0); - } - } -}; +DEFINE_UINT4X2_TO_TYPE_CASTER(int8_t) +DEFINE_UINT4X2_TO_TYPE_CASTER(uint8_t) +DEFINE_UINT4X2_TO_TYPE_CASTER(int16_t) +DEFINE_UINT4X2_TO_TYPE_CASTER(uint16_t) +DEFINE_UINT4X2_TO_TYPE_CASTER(int32_t) +DEFINE_UINT4X2_TO_TYPE_CASTER(uint32_t) +DEFINE_UINT4X2_TO_TYPE_CASTER(int64_t) +DEFINE_UINT4X2_TO_TYPE_CASTER(uint64_t) +DEFINE_UINT4X2_TO_TYPE_CASTER(double) +DEFINE_UINT4X2_TO_TYPE_CASTER(bool) +DEFINE_UINT4X2_TO_TYPE_CASTER(MLFloat16) +DEFINE_UINT4X2_TO_TYPE_CASTER(BFloat16) +DEFINE_UINT4X2_TO_TYPE_CASTER(std::string) +#if !defined(DISABLE_FLOAT8_TYPES) +DEFINE_UINT4X2_TO_TYPE_CASTER(Float8E4M3FN) +DEFINE_UINT4X2_TO_TYPE_CASTER(Float8E4M3FNUZ) +DEFINE_UINT4X2_TO_TYPE_CASTER(Float8E5M2) +DEFINE_UINT4X2_TO_TYPE_CASTER(Float8E5M2FNUZ) +#endif -// Tensor any type to UInt4x2 -template -struct TensorCaster { - 
void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - // Confirm we can pack to uint4x2 - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output UInt4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); - - // Process pairs of elements, packing them into UInt4x2 - size_t i = 0; - for (; i < in_shape_size - 1; i += 2) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); - out_data[i / 2] = UInt4x2(low_val, high_val); - } +// Various types to Int4x2/UInt4x2 +DEFINE_TYPE_TO_INT4X2_CASTER(int8_t) +DEFINE_TYPE_TO_INT4X2_CASTER(uint8_t) +DEFINE_TYPE_TO_INT4X2_CASTER(int16_t) +DEFINE_TYPE_TO_INT4X2_CASTER(uint16_t) +DEFINE_TYPE_TO_INT4X2_CASTER(int32_t) +DEFINE_TYPE_TO_INT4X2_CASTER(uint32_t) +DEFINE_TYPE_TO_INT4X2_CASTER(int64_t) +DEFINE_TYPE_TO_INT4X2_CASTER(uint64_t) +DEFINE_TYPE_TO_INT4X2_CASTER(float) +DEFINE_TYPE_TO_INT4X2_CASTER(double) +DEFINE_TYPE_TO_INT4X2_CASTER(bool) +DEFINE_TYPE_TO_INT4X2_CASTER(MLFloat16) +DEFINE_TYPE_TO_INT4X2_CASTER(BFloat16) +DEFINE_TYPE_TO_INT4X2_CASTER(std::string) +#if !defined(DISABLE_FLOAT8_TYPES) +DEFINE_TYPE_TO_INT4X2_CASTER(Float8E4M3FN) +DEFINE_TYPE_TO_INT4X2_CASTER(Float8E4M3FNUZ) +DEFINE_TYPE_TO_INT4X2_CASTER(Float8E5M2) +DEFINE_TYPE_TO_INT4X2_CASTER(Float8E5M2FNUZ) +#endif - // Handle odd number of elements by padding with 0 - if (i < in_shape_size) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - out_data[i / 2] = UInt4x2(low_val, 0); - } - } -}; +DEFINE_TYPE_TO_UINT4X2_CASTER(int8_t) +DEFINE_TYPE_TO_UINT4X2_CASTER(uint8_t) +DEFINE_TYPE_TO_UINT4X2_CASTER(int16_t) +DEFINE_TYPE_TO_UINT4X2_CASTER(uint16_t) +DEFINE_TYPE_TO_UINT4X2_CASTER(int32_t) +DEFINE_TYPE_TO_UINT4X2_CASTER(uint32_t) +DEFINE_TYPE_TO_UINT4X2_CASTER(int64_t) +DEFINE_TYPE_TO_UINT4X2_CASTER(uint64_t) +DEFINE_TYPE_TO_UINT4X2_CASTER(float) +DEFINE_TYPE_TO_UINT4X2_CASTER(double) +DEFINE_TYPE_TO_UINT4X2_CASTER(bool) +DEFINE_TYPE_TO_UINT4X2_CASTER(MLFloat16) +DEFINE_TYPE_TO_UINT4X2_CASTER(BFloat16) +DEFINE_TYPE_TO_UINT4X2_CASTER(std::string) +#if !defined(DISABLE_FLOAT8_TYPES) +DEFINE_TYPE_TO_UINT4X2_CASTER(Float8E4M3FN) +DEFINE_TYPE_TO_UINT4X2_CASTER(Float8E4M3FNUZ) +DEFINE_TYPE_TO_UINT4X2_CASTER(Float8E5M2) +DEFINE_TYPE_TO_UINT4X2_CASTER(Float8E5M2FNUZ) +#endif #if defined(_M_AMD64) && !defined(_M_ARM64EC) // specializations to use optimized and Windows x64-specific From 8cd523707439b14e43cb0df9a2e0ccf610f6f9ae Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Fri, 6 Jun 2025 17:03:15 -0700 Subject: [PATCH 13/88] remove macros, expand specializations --- .../core/providers/cpu/tensor/cast_op.cc | 991 ++++++++++++++++-- 1 file changed, 914 insertions(+), 77 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index de436ce4391fe..40c0c2e9e4fb6 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -659,91 +659,928 @@ struct TensorCaster { } } }; -// Generate all the required specializations -// Int4x2/UInt4x2 to various types -DEFINE_INT4X2_TO_TYPE_CASTER(int8_t) -DEFINE_INT4X2_TO_TYPE_CASTER(uint8_t) -DEFINE_INT4X2_TO_TYPE_CASTER(int16_t) -DEFINE_INT4X2_TO_TYPE_CASTER(uint16_t) 
-DEFINE_INT4X2_TO_TYPE_CASTER(int32_t) -DEFINE_INT4X2_TO_TYPE_CASTER(uint32_t) -DEFINE_INT4X2_TO_TYPE_CASTER(int64_t) -DEFINE_INT4X2_TO_TYPE_CASTER(uint64_t) -DEFINE_INT4X2_TO_TYPE_CASTER(double) -DEFINE_INT4X2_TO_TYPE_CASTER(bool) -DEFINE_INT4X2_TO_TYPE_CASTER(MLFloat16) -DEFINE_INT4X2_TO_TYPE_CASTER(BFloat16) -DEFINE_INT4X2_TO_TYPE_CASTER(std::string) + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, 
Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t 
low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + #if !defined(DISABLE_FLOAT8_TYPES) -DEFINE_INT4X2_TO_TYPE_CASTER(Float8E4M3FN) -DEFINE_INT4X2_TO_TYPE_CASTER(Float8E4M3FNUZ) -DEFINE_INT4X2_TO_TYPE_CASTER(Float8E5M2) -DEFINE_INT4X2_TO_TYPE_CASTER(Float8E5M2FNUZ) + +template <> +struct TensorCaster { + void Cast(const 
OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + #endif -DEFINE_UINT4X2_TO_TYPE_CASTER(int8_t) -DEFINE_UINT4X2_TO_TYPE_CASTER(uint8_t) -DEFINE_UINT4X2_TO_TYPE_CASTER(int16_t) -DEFINE_UINT4X2_TO_TYPE_CASTER(uint16_t) -DEFINE_UINT4X2_TO_TYPE_CASTER(int32_t) -DEFINE_UINT4X2_TO_TYPE_CASTER(uint32_t) -DEFINE_UINT4X2_TO_TYPE_CASTER(int64_t) -DEFINE_UINT4X2_TO_TYPE_CASTER(uint64_t) -DEFINE_UINT4X2_TO_TYPE_CASTER(double) -DEFINE_UINT4X2_TO_TYPE_CASTER(bool) -DEFINE_UINT4X2_TO_TYPE_CASTER(MLFloat16) -DEFINE_UINT4X2_TO_TYPE_CASTER(BFloat16) 
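Aside, not part of the patch: the Int4x2 specializations above all share the same signed-nibble trick, shifting the target nibble into the top bits of a signed byte and then shifting back down so the sign bit is extended, while the UInt4x2 variants only mask. A minimal standalone sketch of that logic follows; the helper names are hypothetical, it does not use the onnxruntime Tensor/Int4x2 types, and it assumes the usual arithmetic right shift for signed values.

#include <cassert>
#include <cstdint>

// Signed unpack: move the nibble into the top bits of a signed byte, then
// shift back down so the sign bit is extended (assumes arithmetic right shift).
inline void unpack_int4(uint8_t packed, int8_t& low, int8_t& high) {
  high = static_cast<int8_t>(packed) >> 4;
  low = static_cast<int8_t>(static_cast<uint8_t>(packed << 4)) >> 4;
}

// Unsigned unpack: plain masking, no sign extension involved.
inline void unpack_uint4(uint8_t packed, uint8_t& low, uint8_t& high) {
  high = (packed >> 4) & 0x0F;
  low = packed & 0x0F;
}

int main() {
  int8_t lo = 0, hi = 0;
  unpack_int4(0xF2, lo, hi);     // signed view: 0xF sign-extends to -1, 0x2 stays 2
  assert(lo == 2 && hi == -1);

  uint8_t ulo = 0, uhi = 0;
  unpack_uint4(0xF2, ulo, uhi);  // unsigned view: 2 and 15
  assert(ulo == 2 && uhi == 15);
  return 0;
}

The low nibble is written out first (index 2 * i) and the high nibble second (2 * i + 1), which is the ordering every specialization above follows.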
-DEFINE_UINT4X2_TO_TYPE_CASTER(std::string) -#if !defined(DISABLE_FLOAT8_TYPES) -DEFINE_UINT4X2_TO_TYPE_CASTER(Float8E4M3FN) -DEFINE_UINT4X2_TO_TYPE_CASTER(Float8E4M3FNUZ) -DEFINE_UINT4X2_TO_TYPE_CASTER(Float8E5M2) -DEFINE_UINT4X2_TO_TYPE_CASTER(Float8E5M2FNUZ) -#endif +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); -// Various types to Int4x2/UInt4x2 
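For the opposite direction that the removed "// Various types to Int4x2/UInt4x2" comment marked, i.e. the per-type packing casters added further down in this patch, the loop pairs consecutive source elements and pads a trailing odd element with a zero high nibble. Below is a rough standalone sketch of that pairing with hypothetical names; it assumes the per-element values have already been reduced to 4-bit range (the series' ToInt4ElementConverter helpers handle the actual per-element conversion), and it writes the loop bound as i + 1 < count so that an empty input cannot wrap an unsigned count - 1.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: pack signed 4-bit values two per byte, low nibble first.
// A trailing odd element is paired with an implicit zero, as in the casters below.
inline std::vector<uint8_t> pack_int4_pairs(const std::vector<int8_t>& vals) {
  std::vector<uint8_t> packed((vals.size() + 1) / 2, 0);
  size_t i = 0;
  // "i + 1 < size" avoids the unsigned wrap-around that "i < size - 1"
  // would hit for an empty input.
  for (; i + 1 < vals.size(); i += 2) {
    packed[i / 2] = static_cast<uint8_t>((vals[i] & 0x0F) |
                                         ((vals[i + 1] & 0x0F) << 4));
  }
  if (i < vals.size()) {  // odd element count: high nibble stays zero
    packed[i / 2] = static_cast<uint8_t>(vals[i] & 0x0F);
  }
  return packed;
}

For example, {1, -3, 2} packs to the two bytes 0xD1 and 0x02, since -3 is 0xD in 4-bit two's complement and the low nibble holds the first element of each pair.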
-DEFINE_TYPE_TO_INT4X2_CASTER(int8_t) -DEFINE_TYPE_TO_INT4X2_CASTER(uint8_t) -DEFINE_TYPE_TO_INT4X2_CASTER(int16_t) -DEFINE_TYPE_TO_INT4X2_CASTER(uint16_t) -DEFINE_TYPE_TO_INT4X2_CASTER(int32_t) -DEFINE_TYPE_TO_INT4X2_CASTER(uint32_t) -DEFINE_TYPE_TO_INT4X2_CASTER(int64_t) -DEFINE_TYPE_TO_INT4X2_CASTER(uint64_t) -DEFINE_TYPE_TO_INT4X2_CASTER(float) -DEFINE_TYPE_TO_INT4X2_CASTER(double) -DEFINE_TYPE_TO_INT4X2_CASTER(bool) -DEFINE_TYPE_TO_INT4X2_CASTER(MLFloat16) -DEFINE_TYPE_TO_INT4X2_CASTER(BFloat16) -DEFINE_TYPE_TO_INT4X2_CASTER(std::string) -#if !defined(DISABLE_FLOAT8_TYPES) -DEFINE_TYPE_TO_INT4X2_CASTER(Float8E4M3FN) -DEFINE_TYPE_TO_INT4X2_CASTER(Float8E4M3FNUZ) -DEFINE_TYPE_TO_INT4X2_CASTER(Float8E5M2) -DEFINE_TYPE_TO_INT4X2_CASTER(Float8E5M2FNUZ) -#endif + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = 
Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; -DEFINE_TYPE_TO_UINT4X2_CASTER(int8_t) -DEFINE_TYPE_TO_UINT4X2_CASTER(uint8_t) -DEFINE_TYPE_TO_UINT4X2_CASTER(int16_t) -DEFINE_TYPE_TO_UINT4X2_CASTER(uint16_t) -DEFINE_TYPE_TO_UINT4X2_CASTER(int32_t) -DEFINE_TYPE_TO_UINT4X2_CASTER(uint32_t) -DEFINE_TYPE_TO_UINT4X2_CASTER(int64_t) -DEFINE_TYPE_TO_UINT4X2_CASTER(uint64_t) -DEFINE_TYPE_TO_UINT4X2_CASTER(float) -DEFINE_TYPE_TO_UINT4X2_CASTER(double) -DEFINE_TYPE_TO_UINT4X2_CASTER(bool) -DEFINE_TYPE_TO_UINT4X2_CASTER(MLFloat16) -DEFINE_TYPE_TO_UINT4X2_CASTER(BFloat16) -DEFINE_TYPE_TO_UINT4X2_CASTER(std::string) #if !defined(DISABLE_FLOAT8_TYPES) -DEFINE_TYPE_TO_UINT4X2_CASTER(Float8E4M3FN) -DEFINE_TYPE_TO_UINT4X2_CASTER(Float8E4M3FNUZ) -DEFINE_TYPE_TO_UINT4X2_CASTER(Float8E5M2) -DEFINE_TYPE_TO_UINT4X2_CASTER(Float8E5M2FNUZ) +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < 
in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t shape_size = narrow(shape.Size()); + const size_t in_shape_size = narrow(in.Shape().Size()); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed >> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + } + } +}; #endif +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, + "The output Int4x2 tensor size is invalid for casting from float."); + + size_t i = 0; + for (; i < in_shape_size - 1; i += 2) { + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + out_data[i / 2] = Int4x2(low_val, high_val); + } + + if (i < in_shape_size) { + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + out_data[i / 2] = Int4x2(low_val, 0); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& 
out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, + "The output UInt4x2 tensor size is invalid for casting from float."); + + size_t i = 0; + for (; i < in_shape_size - 1; i += 2) { + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + out_data[i / 2] = UInt4x2(low_val, high_val); + } + + if (i < in_shape_size) { + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + out_data[i / 2] = UInt4x2(low_val, 0); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, + "The output Int4x2 tensor size is invalid for casting from int32_t."); + + size_t i = 0; + for (; i < in_shape_size - 1; i += 2) { + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + out_data[i / 2] = Int4x2(low_val, high_val); + } + + if (i < in_shape_size) { + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + out_data[i / 2] = Int4x2(low_val, 0); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, + "The output Int4x2 tensor size is invalid for casting from int8_t."); + + size_t i = 0; + for (; i < in_shape_size - 1; i += 2) { + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + out_data[i / 2] = Int4x2(low_val, high_val); + } + + if (i < in_shape_size) { + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + out_data[i / 2] = Int4x2(low_val, 0); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, + "The output Int4x2 tensor size is invalid for casting from BFloat16."); + + size_t i = 0; + for (; i < in_shape_size - 1; i += 2) { + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + out_data[i / 2] = Int4x2(low_val, high_val); + } + + if (i < in_shape_size) { + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + out_data[i / 2] = Int4x2(low_val, 0); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t 
in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, + "The output Int4x2 tensor size is invalid for casting from MLFloat16."); + + size_t i = 0; + for (; i < in_shape_size - 1; i += 2) { + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + out_data[i / 2] = Int4x2(low_val, high_val); + } + + if (i < in_shape_size) { + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + out_data[i / 2] = Int4x2(low_val, 0); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, + "The output UInt4x2 tensor size is invalid for casting from int32_t."); + + size_t i = 0; + for (; i < in_shape_size - 1; i += 2) { + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + out_data[i / 2] = UInt4x2(low_val, high_val); + } + + if (i < in_shape_size) { + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + out_data[i / 2] = UInt4x2(low_val, 0); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, + "The output UInt4x2 tensor size is invalid for casting from uint8_t."); + + size_t i = 0; + for (; i < in_shape_size - 1; i += 2) { + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + out_data[i / 2] = UInt4x2(low_val, high_val); + } + + if (i < in_shape_size) { + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + out_data[i / 2] = UInt4x2(low_val, 0); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, + "The output UInt4x2 tensor size is invalid for casting from uint16_t."); + + size_t i = 0; + for (; i < in_shape_size - 1; i += 2) { + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + out_data[i / 2] = UInt4x2(low_val, high_val); + } + + if (i < in_shape_size) { + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + out_data[i / 2] = UInt4x2(low_val, 0); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = 
narrow(out.Shape().Size()); + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, + "The output UInt4x2 tensor size is invalid for casting from uint32_t."); + + size_t i = 0; + for (; i < in_shape_size - 1; i += 2) { + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + out_data[i / 2] = UInt4x2(low_val, high_val); + } + + if (i < in_shape_size) { + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + out_data[i / 2] = UInt4x2(low_val, 0); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, + "The output UInt4x2 tensor size is invalid for casting from double."); + + size_t i = 0; + for (; i < in_shape_size - 1; i += 2) { + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + out_data[i / 2] = UInt4x2(low_val, high_val); + } + + if (i < in_shape_size) { + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + out_data[i / 2] = UInt4x2(low_val, 0); + } + } +}; + #if defined(_M_AMD64) && !defined(_M_ARM64EC) // specializations to use optimized and Windows x64-specific From 5e23ee2c958ac6239823b5641d953db274c49ba2 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Fri, 6 Jun 2025 17:11:53 -0700 Subject: [PATCH 14/88] merge together similar specializations --- .../core/providers/cpu/tensor/cast_op.cc | 887 +++++------------- 1 file changed, 244 insertions(+), 643 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 40c0c2e9e4fb6..52f381f222bb8 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -450,110 +450,68 @@ struct ToInt4ElementConverter { return std::min(result, 15u); } }; -} // anonymous namespace -#define DEFINE_INT4X2_TO_TYPE_CASTER(TYPE) \ - template <> \ - struct TensorCaster { \ - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { \ - const auto* in_data = in.Data(); \ - auto* out_data = out.MutableData(); \ - \ - const size_t shape_size = narrow(shape.Size()); \ - const size_t in_shape_size = narrow(in.Shape().Size()); \ - ORT_ENFORCE(in_shape_size * 2 == shape_size, \ - "The Int4x2 tensor size is invalid for casting."); \ - \ - for (size_t i = 0; i < in_shape_size; ++i) { \ - const uint8_t packed = static_cast(in_data[i].bits_); \ - \ - int8_t high_nibble = static_cast(packed) >> 4; \ - int8_t low_nibble = static_cast(packed << 4) >> 4; \ - \ - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); \ - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); \ - } \ - } \ - }; - -// Define a macro to generate full specializations for UInt4x2 to specific types -#define DEFINE_UINT4X2_TO_TYPE_CASTER(TYPE) \ - template <> \ - struct TensorCaster { \ - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { \ - const auto* in_data = in.Data(); \ - auto* out_data = out.MutableData(); \ - \ - const size_t shape_size = narrow(shape.Size()); \ - const size_t in_shape_size = 
narrow(in.Shape().Size()); \ - ORT_ENFORCE(in_shape_size * 2 == shape_size, \ - "The UInt4x2 tensor size is invalid for casting."); \ - \ - for (size_t i = 0; i < in_shape_size; ++i) { \ - const uint8_t packed = static_cast(in_data[i].bits_); \ - \ - uint8_t high_nibble = (packed >> 4) & 0x0F; \ - uint8_t low_nibble = packed & 0x0F; \ - \ - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); \ - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); \ - } \ - } \ - }; - -// Define a macro to generate full specializations for type to Int4x2 -#define DEFINE_TYPE_TO_INT4X2_CASTER(TYPE) \ - template <> \ - struct TensorCaster { \ - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { \ - const auto* in_data = in.Data(); \ - auto* out_data = out.MutableData(); \ - \ - const size_t in_shape_size = narrow(shape.Size()); \ - const size_t out_shape_size = narrow(out.Shape().Size()); \ - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, \ - "The output Int4x2 tensor size is invalid for casting from ", typeid(TYPE).name()); \ - \ - size_t i = 0; \ - for (; i < in_shape_size - 1; i += 2) { \ - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); \ - int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); \ - out_data[i / 2] = Int4x2(low_val, high_val); \ - } \ - \ - if (i < in_shape_size) { \ - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); \ - out_data[i / 2] = Int4x2(low_val, 0); \ - } \ - } \ - }; - -// Define a macro to generate full specializations for type to UInt4x2 -#define DEFINE_TYPE_TO_UINT4X2_CASTER(TYPE) \ - template <> \ - struct TensorCaster { \ - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { \ - const auto* in_data = in.Data(); \ - auto* out_data = out.MutableData(); \ - \ - const size_t in_shape_size = narrow(shape.Size()); \ - const size_t out_shape_size = narrow(out.Shape().Size()); \ - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, \ - "The output UInt4x2 tensor size is invalid for casting from ", typeid(TYPE).name()); \ - \ - size_t i = 0; \ - for (; i < in_shape_size - 1; i += 2) { \ - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); \ - uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); \ - out_data[i / 2] = UInt4x2(low_val, high_val); \ - } \ - \ - if (i < in_shape_size) { \ - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); \ - out_data[i / 2] = UInt4x2(low_val, 0); \ - } \ - } \ - }; +// Check if a type is one of the integer types +template +struct is_standard_integer { + static constexpr bool value = + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value; +}; + +// Check if a type is one of the floating point types +template +struct is_standard_float { + static constexpr bool value = + std::is_same::value || + std::is_same::value; +}; + +// Check if a type is one of the half-precision float types +template +struct is_half_float { + static constexpr bool value = + std::is_same::value || + std::is_same::value; +}; + +#if !defined(DISABLE_FLOAT8_TYPES) +// Check if a type is one of the 8-bit float types +template +struct is_float8_type { + static constexpr bool value = + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value; +}; +#endif + +// Enable if for 
standard integers +template +using enable_if_standard_integer = typename std::enable_if::value, void>::type; + +// Enable if for standard floats +template +using enable_if_standard_float = typename std::enable_if::value, void>::type; + +// Enable if for half floats +template +using enable_if_half_float = typename std::enable_if::value, void>::type; + +#if !defined(DISABLE_FLOAT8_TYPES) +// Enable if for float8 types +template +using enable_if_float8_type = typename std::enable_if::value, void>::type; +#endif + +} // anonymous namespace // generic tensor X -> Y template @@ -608,143 +566,13 @@ struct TensorCaster { } }; -// tensor Int4x2 -> float -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - // Confirm we can unpack the int4 - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, - "The Int4x2 tensor size is invalid for casting to float."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - - // Extract signed high and low nibble - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - - out_data[2 * i] = static_cast(low_nibble); - out_data[2 * i + 1] = static_cast(high_nibble); - } - } -}; - -// tensor UInt4x2 -> float -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - // Confirm we can unpack the uint4 - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, - "The UInt4x2 tensor size is invalid for casting to float."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - - // Extract unsigned high and low nibble - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; - - out_data[2 * i] = static_cast(low_nibble); - out_data[2 * i + 1] = static_cast(high_nibble); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { 
- const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -template <> -struct TensorCaster { +template +struct TensorCaster::value && + !std::is_same::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); const size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); @@ -754,68 +582,33 @@ struct TensorCaster { const uint8_t packed = static_cast(in_data[i].bits_); int8_t high_nibble = static_cast(packed) >> 4; int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); } } }; template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); const size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); + ORT_ENFORCE(in_shape_size * 2 == shape_size, + "The Int4x2 tensor size is invalid for casting to float."); for (size_t i = 0; i < in_shape_size; ++i) { const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = 
Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); + // Extract signed high and low nibble int8_t high_nibble = static_cast(packed) >> 4; int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + // Low nibble first, then high nibble + out_data[2 * i] = static_cast(low_nibble); + out_data[2 * i + 1] = static_cast(high_nibble); } } }; @@ -834,17 +627,18 @@ struct TensorCaster { const uint8_t packed = static_cast(in_data[i].bits_); int8_t high_nibble = static_cast(packed) >> 4; int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + out_data[2 * i] = static_cast(low_nibble); + out_data[2 * i + 1] = static_cast(high_nibble); } } }; -template <> -struct TensorCaster { +template +struct TensorCaster::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); const size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); @@ -854,17 +648,19 @@ struct TensorCaster { const uint8_t packed = static_cast(in_data[i].bits_); int8_t high_nibble = static_cast(packed) >> 4; int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); } } }; -template <> -struct TensorCaster { +#if !defined(DISABLE_FLOAT8_TYPES) +template +struct TensorCaster::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); const size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); @@ 
-874,17 +670,18 @@ struct TensorCaster { const uint8_t packed = static_cast(in_data[i].bits_); int8_t high_nibble = static_cast(packed) >> 4; int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); } } }; +#endif template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); const size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); @@ -894,8 +691,8 @@ struct TensorCaster { const uint8_t packed = static_cast(in_data[i].bits_); int8_t high_nibble = static_cast(packed) >> 4; int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + out_data[2 * i] = low_nibble != 0; + out_data[2 * i + 1] = high_nibble != 0; } } }; @@ -920,175 +717,39 @@ struct TensorCaster { } }; -#if !defined(DISABLE_FLOAT8_TYPES) - template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = 
Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -#endif - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); + // Confirm we can unpack the uint4 const size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + ORT_ENFORCE(in_shape_size * 2 == shape_size, + "The UInt4x2 tensor size is invalid for casting to float."); for (size_t i = 0; i < in_shape_size; ++i) { const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); + // Extract unsigned high and low nibble uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = 
in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + // Low nibble first, then high nibble + out_data[2 * i] = static_cast(low_nibble); + out_data[2 * i + 1] = static_cast(high_nibble); } } }; -template <> -struct TensorCaster { +template +struct TensorCaster::value && + !std::is_same::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); const size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); @@ -1098,17 +759,17 @@ struct TensorCaster { const uint8_t packed = static_cast(in_data[i].bits_); uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); } } }; template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); const size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); @@ -1118,17 +779,18 @@ struct TensorCaster { const uint8_t packed = static_cast(in_data[i].bits_); uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + out_data[2 * i] = static_cast(low_nibble); + out_data[2 * i + 1] = static_cast(high_nibble); } } }; -template <> -struct TensorCaster { +template +struct TensorCaster::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); const size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); @@ -1138,17 +800,19 @@ struct TensorCaster { const uint8_t packed = static_cast(in_data[i].bits_); uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); } } }; -template <> -struct TensorCaster { +#if !defined(DISABLE_FLOAT8_TYPES) +template +struct TensorCaster::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); const 
size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); @@ -1158,11 +822,12 @@ struct TensorCaster { const uint8_t packed = static_cast(in_data[i].bits_); uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); } } }; +#endif template <> struct TensorCaster { @@ -1178,58 +843,17 @@ struct TensorCaster { const uint8_t packed = static_cast(in_data[i].bits_); uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + out_data[2 * i] = low_nibble != 0; + out_data[2 * i + 1] = high_nibble != 0; } } }; template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -#if !defined(DISABLE_FLOAT8_TYPES) -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); const size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); @@ -1239,347 +863,324 @@ struct TensorCaster { const uint8_t packed = static_cast(in_data[i].bits_); uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); + out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); } } }; -template <> -struct TensorCaster { +template +struct TensorCaster::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = 
in.Data(); - auto* out_data = out.MutableData(); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, + "The output Int4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + size_t i = 0; + for (; i < in_shape_size - 1; i += 2) { + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + out_data[i / 2] = Int4x2(low_val, high_val); } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + if (i < in_shape_size) { + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + out_data[i / 2] = Int4x2(low_val, 0); } } }; -#endif -template <> -struct TensorCaster { +template +struct TensorCaster::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); + const auto* in_data = in.Data(); auto* out_data = out.MutableData(); const size_t in_shape_size = narrow(shape.Size()); const size_t out_shape_size = narrow(out.Shape().Size()); ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output Int4x2 tensor size is invalid for casting from float."); + "The output Int4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); size_t i = 0; for (; i < in_shape_size - 1; i += 2) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - int8_t high_val = 
ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); out_data[i / 2] = Int4x2(low_val, high_val); } if (i < in_shape_size) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); out_data[i / 2] = Int4x2(low_val, 0); } } }; -template <> -struct TensorCaster { +template +struct TensorCaster::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); const size_t in_shape_size = narrow(shape.Size()); const size_t out_shape_size = narrow(out.Shape().Size()); ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output UInt4x2 tensor size is invalid for casting from float."); + "The output Int4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); size_t i = 0; for (; i < in_shape_size - 1; i += 2) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); - out_data[i / 2] = UInt4x2(low_val, high_val); + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + out_data[i / 2] = Int4x2(low_val, high_val); } if (i < in_shape_size) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - out_data[i / 2] = UInt4x2(low_val, 0); + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + out_data[i / 2] = Int4x2(low_val, 0); } } }; template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); + const auto* in_data = in.Data(); auto* out_data = out.MutableData(); const size_t in_shape_size = narrow(shape.Size()); const size_t out_shape_size = narrow(out.Shape().Size()); ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output Int4x2 tensor size is invalid for casting from int32_t."); + "The output Int4x2 tensor size is invalid for casting from bool."); size_t i = 0; for (; i < in_shape_size - 1; i += 2) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + int8_t low_val = in_data[i] ? 1 : 0; + int8_t high_val = in_data[i + 1] ? 1 : 0; out_data[i / 2] = Int4x2(low_val, high_val); } if (i < in_shape_size) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t low_val = in_data[i] ? 
1 : 0; out_data[i / 2] = Int4x2(low_val, 0); } } }; template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); + const auto* in_data = in.Data(); auto* out_data = out.MutableData(); const size_t in_shape_size = narrow(shape.Size()); const size_t out_shape_size = narrow(out.Shape().Size()); ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output Int4x2 tensor size is invalid for casting from int8_t."); + "The output Int4x2 tensor size is invalid for casting from string."); size_t i = 0; for (; i < in_shape_size - 1; i += 2) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); out_data[i / 2] = Int4x2(low_val, high_val); } if (i < in_shape_size) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); out_data[i / 2] = Int4x2(low_val, 0); } } }; -template <> -struct TensorCaster { +#if !defined(DISABLE_FLOAT8_TYPES) +template +struct TensorCaster::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); + const auto* in_data = in.Data(); auto* out_data = out.MutableData(); const size_t in_shape_size = narrow(shape.Size()); const size_t out_shape_size = narrow(out.Shape().Size()); ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output Int4x2 tensor size is invalid for casting from BFloat16."); + "The output Int4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); size_t i = 0; for (; i < in_shape_size - 1; i += 2) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); out_data[i / 2] = Int4x2(low_val, high_val); } if (i < in_shape_size) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); out_data[i / 2] = Int4x2(low_val, 0); } } }; +#endif -template <> -struct TensorCaster { + +template +struct TensorCaster::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); const size_t in_shape_size = narrow(shape.Size()); const size_t out_shape_size = narrow(out.Shape().Size()); ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output Int4x2 tensor size is invalid for casting from MLFloat16."); + "The output UInt4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); size_t i = 0; for (; i < in_shape_size - 1; i += 2) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); - out_data[i / 2] = Int4x2(low_val, high_val); + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + out_data[i / 2] = 
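Editor's note: the hand-written casters above (and the SrcType-generic ones replacing them) all share one packing loop: walk the unpacked source two elements at a time, pack each pair into a single nibble-pair, and pad a trailing odd element with zero in the high nibble. Below is a minimal standalone sketch of that loop shape, using a hypothetical PackedInt4 struct in place of onnxruntime's Int4x2; the names are illustrative only, not the actual ORT types.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for a packed pair of signed 4-bit values:
// element 0 lives in the low nibble, element 1 in the high nibble.
struct PackedInt4 {
  uint8_t bits;
  PackedInt4(int8_t lo, int8_t hi)
      : bits(static_cast<uint8_t>((lo & 0x0F) | ((hi & 0x0F) << 4))) {}
};

// Pack n int8_t values into ceil(n / 2) PackedInt4 entries,
// padding an odd trailing element with 0 in the high nibble.
std::vector<PackedInt4> PackInt4(const int8_t* src, size_t n) {
  std::vector<PackedInt4> out;
  out.reserve((n + 1) / 2);
  size_t i = 0;
  for (; i + 1 < n; i += 2) {
    out.emplace_back(src[i], src[i + 1]);  // low nibble first, then high
  }
  if (i < n) {
    out.emplace_back(src[i], 0);  // odd element count: high nibble is zero-padded
  }
  return out;
}

int main() {
  const int8_t values[] = {1, -3, 7};  // odd count on purpose
  auto packed = PackInt4(values, 3);
  for (const auto& p : packed) printf("0x%02X\n", static_cast<unsigned>(p.bits));  // 0xD1 then 0x07
  return 0;
}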
UInt4x2(low_val, high_val); } if (i < in_shape_size) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - out_data[i / 2] = Int4x2(low_val, 0); + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + out_data[i / 2] = UInt4x2(low_val, 0); } } }; -template <> -struct TensorCaster { +template +struct TensorCaster::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); + const auto* in_data = in.Data(); auto* out_data = out.MutableData(); const size_t in_shape_size = narrow(shape.Size()); const size_t out_shape_size = narrow(out.Shape().Size()); ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output UInt4x2 tensor size is invalid for casting from int32_t."); + "The output UInt4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); size_t i = 0; for (; i < in_shape_size - 1; i += 2) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); out_data[i / 2] = UInt4x2(low_val, high_val); } if (i < in_shape_size) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); out_data[i / 2] = UInt4x2(low_val, 0); } } }; -template <> -struct TensorCaster { +template +struct TensorCaster::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); + const auto* in_data = in.Data(); auto* out_data = out.MutableData(); const size_t in_shape_size = narrow(shape.Size()); const size_t out_shape_size = narrow(out.Shape().Size()); ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output UInt4x2 tensor size is invalid for casting from uint8_t."); + "The output UInt4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); size_t i = 0; for (; i < in_shape_size - 1; i += 2) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); out_data[i / 2] = UInt4x2(low_val, high_val); } if (i < in_shape_size) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); out_data[i / 2] = UInt4x2(low_val, 0); } } }; template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); + const auto* in_data = in.Data(); auto* out_data = out.MutableData(); const size_t in_shape_size = narrow(shape.Size()); const size_t out_shape_size = narrow(out.Shape().Size()); ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output UInt4x2 tensor size is invalid for casting from uint16_t."); + "The output UInt4x2 tensor size is invalid for casting from bool."); size_t i = 0; for (; i < in_shape_size - 1; i += 2) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + uint8_t low_val = in_data[i] ? 
1 : 0; + uint8_t high_val = in_data[i + 1] ? 1 : 0; out_data[i / 2] = UInt4x2(low_val, high_val); } if (i < in_shape_size) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t low_val = in_data[i] ? 1 : 0; out_data[i / 2] = UInt4x2(low_val, 0); } } }; template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); + const auto* in_data = in.Data(); auto* out_data = out.MutableData(); const size_t in_shape_size = narrow(shape.Size()); const size_t out_shape_size = narrow(out.Shape().Size()); ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output UInt4x2 tensor size is invalid for casting from uint32_t."); + "The output UInt4x2 tensor size is invalid for casting from string."); size_t i = 0; for (; i < in_shape_size - 1; i += 2) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); out_data[i / 2] = UInt4x2(low_val, high_val); } if (i < in_shape_size) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); out_data[i / 2] = UInt4x2(low_val, 0); } } }; -template <> -struct TensorCaster { +#if !defined(DISABLE_FLOAT8_TYPES) +template +struct TensorCaster::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); + const auto* in_data = in.Data(); auto* out_data = out.MutableData(); const size_t in_shape_size = narrow(shape.Size()); const size_t out_shape_size = narrow(out.Shape().Size()); ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output UInt4x2 tensor size is invalid for casting from double."); + "The output UInt4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); size_t i = 0; for (; i < in_shape_size - 1; i += 2) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); out_data[i / 2] = UInt4x2(low_val, high_val); } if (i < in_shape_size) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); out_data[i / 2] = UInt4x2(low_val, 0); } } }; +#endif #if defined(_M_AMD64) && !defined(_M_ARM64EC) // specializations to use optimized and Windows x64-specific From 9d50386e1d441e2b35753fc81edec4fedbaf960b Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 9 Jun 2025 09:11:02 -0700 Subject: [PATCH 15/88] remove unused aliases --- .../core/providers/cpu/tensor/cast_op.cc | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 52f381f222bb8..44f51c9451488 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -493,24 +493,6 @@ struct is_float8_type { }; #endif -// Enable if for standard integers -template -using enable_if_standard_integer = typename std::enable_if::value, 
void>::type; - -// Enable if for standard floats -template -using enable_if_standard_float = typename std::enable_if::value, void>::type; - -// Enable if for half floats -template -using enable_if_half_float = typename std::enable_if::value, void>::type; - -#if !defined(DISABLE_FLOAT8_TYPES) -// Enable if for float8 types -template -using enable_if_float8_type = typename std::enable_if::value, void>::type; -#endif - } // anonymous namespace // generic tensor X -> Y From 8474ca73a1408b2965cac3929282042855a197dd Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 9 Jun 2025 09:38:28 -0700 Subject: [PATCH 16/88] clean up template specializations --- .../core/providers/cpu/tensor/cast_op.cc | 85 +++++-------------- 1 file changed, 22 insertions(+), 63 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 44f51c9451488..7412797d8ae87 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -550,8 +550,7 @@ struct TensorCaster { template struct TensorCaster::value && - !std::is_same::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -570,16 +569,16 @@ struct TensorCaster -struct TensorCaster { +template +struct TensorCaster::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); const size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, - "The Int4x2 tensor size is invalid for casting to float."); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); for (size_t i = 0; i < in_shape_size; ++i) { const uint8_t packed = static_cast(in_data[i].bits_); @@ -589,28 +588,8 @@ struct TensorCaster { int8_t low_nibble = static_cast(packed << 4) >> 4; // Low nibble first, then high nibble - out_data[2 * i] = static_cast(low_nibble); - out_data[2 * i + 1] = static_cast(high_nibble); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = static_cast(low_nibble); - out_data[2 * i + 1] = static_cast(high_nibble); + out_data[2 * i] = static_cast(low_nibble); + out_data[2 * i + 1] = static_cast(high_nibble); } } }; @@ -699,36 +678,9 @@ struct TensorCaster { } }; -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - // Confirm we can unpack the uint4 - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = 
narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, - "The UInt4x2 tensor size is invalid for casting to float."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - - // Extract unsigned high and low nibble - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; - - // Low nibble first, then high nibble - out_data[2 * i] = static_cast(low_nibble); - out_data[2 * i + 1] = static_cast(high_nibble); - } - } -}; - template struct TensorCaster::value && - !std::is_same::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -747,22 +699,29 @@ struct TensorCaster -struct TensorCaster { +template +struct TensorCaster::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); + // Confirm we can unpack the uint4 const size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + ORT_ENFORCE(in_shape_size * 2 == shape_size, + "The UInt4x2 tensor size is invalid for casting to float."); for (size_t i = 0; i < in_shape_size; ++i) { const uint8_t packed = static_cast(in_data[i].bits_); + + // Extract unsigned high and low nibble uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = static_cast(low_nibble); - out_data[2 * i + 1] = static_cast(high_nibble); + + // Low nibble first, then high nibble + out_data[2 * i] = static_cast(low_nibble); + out_data[2 * i + 1] = static_cast(high_nibble); } } }; From 1ddcb968ae252f19982c3b0ce5e388d921074269 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 9 Jun 2025 10:07:34 -0700 Subject: [PATCH 17/88] reuse existing aliases, refactor for consistency --- .../core/providers/cpu/tensor/cast_op.cc | 94 +++++++------------ 1 file changed, 36 insertions(+), 58 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 7412797d8ae87..37e93dfdf65a0 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -60,6 +60,26 @@ template using IsOrtFloat8Type = boost::mp11::mp_contains; #endif +template +struct IsStandardIntegerType { + static constexpr bool value = + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value; +}; + +template +struct IsStandardFloatType { + static constexpr bool value = + std::is_same::value || + std::is_same::value; +}; + // string cast helpers // Note: when C++17 is available, use functions @@ -451,48 +471,6 @@ struct ToInt4ElementConverter { } }; -// Check if a type is one of the integer types -template -struct is_standard_integer { - static constexpr bool value = - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value; -}; - -// Check if a type is one of the floating point types -template -struct is_standard_float { 
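Editor's note: the unpack path above relies on a shift trick to sign-extend each 4-bit field: cast the packed byte to int8_t and arithmetic-shift right by 4 for the high nibble, or shift left by 4 first and then back down for the low nibble. A small self-contained illustration follows; it assumes two's-complement narrowing and arithmetic right shift on signed values, which the compilers ONNX Runtime targets provide and which C++20 guarantees.

#include <cassert>
#include <cstdint>

// Recover the two signed 4-bit values stored in one byte
// (low nibble = element 0, high nibble = element 1).
void UnpackSignedNibbles(uint8_t packed, int8_t& lo, int8_t& hi) {
  // The byte's bit 7 is the high nibble's sign bit, so an
  // arithmetic right shift by 4 sign-extends it directly.
  hi = static_cast<int8_t>(packed) >> 4;
  // Move the low nibble's sign bit (bit 3) up to bit 7 first,
  // then shift back down to sign-extend.
  lo = static_cast<int8_t>(static_cast<uint8_t>(packed << 4)) >> 4;
}

int main() {
  // 0xD1 packs lo = 1 (0x1) and hi = -3 (0xD read as a signed nibble).
  int8_t lo = 0, hi = 0;
  UnpackSignedNibbles(0xD1, lo, hi);
  assert(lo == 1 && hi == -3);
  return 0;
}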
- static constexpr bool value = - std::is_same::value || - std::is_same::value; -}; - -// Check if a type is one of the half-precision float types -template -struct is_half_float { - static constexpr bool value = - std::is_same::value || - std::is_same::value; -}; - -#if !defined(DISABLE_FLOAT8_TYPES) -// Check if a type is one of the 8-bit float types -template -struct is_float8_type { - static constexpr bool value = - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value; -}; -#endif - } // anonymous namespace // generic tensor X -> Y @@ -550,7 +528,7 @@ struct TensorCaster { template struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -571,7 +549,7 @@ struct TensorCaster struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -596,7 +574,7 @@ struct TensorCaster struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -618,7 +596,7 @@ struct TensorCaster struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -680,7 +658,7 @@ struct TensorCaster { template struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -701,7 +679,7 @@ struct TensorCaster struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -728,7 +706,7 @@ struct TensorCaster struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -750,7 +728,7 @@ struct TensorCaster struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -812,7 +790,7 @@ struct TensorCaster { template struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -838,7 +816,7 @@ struct TensorCaster struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -864,7 +842,7 @@ struct TensorCaster struct 
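Editor's note: these patches collapse the duplicated casters into a handful of partial specializations selected with std::enable_if over type traits such as IsStandardIntegerType. A minimal sketch of that dispatch pattern follows, using hypothetical names (Caster, plus the standard library's own traits) rather than the actual onnxruntime templates.

#include <cstdint>
#include <cstdio>
#include <string>
#include <type_traits>

// Primary template carries an extra SFINAE slot that defaults to void.
template <typename T, typename Enable = void>
struct Caster {
  static const char* Describe() { return "generic"; }
};

// Chosen only when T is an integral type; for any other T the enable_if
// has no ::type, so this partial specialization is simply never matched (SFINAE).
template <typename T>
struct Caster<T, typename std::enable_if<std::is_integral<T>::value>::type> {
  static const char* Describe() { return "integer path"; }
};

// Chosen only for floating-point types; mutually exclusive with the above,
// so no instantiation ever matches two partial specializations at once.
template <typename T>
struct Caster<T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
  static const char* Describe() { return "float path"; }
};

int main() {
  printf("%s\n", Caster<int32_t>::Describe());      // integer path
  printf("%s\n", Caster<float>::Describe());        // float path
  printf("%s\n", Caster<std::string>::Describe());  // generic
  return 0;
}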
TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -941,7 +919,7 @@ struct TensorCaster { #if !defined(DISABLE_FLOAT8_TYPES) template struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -969,7 +947,7 @@ struct TensorCaster struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -995,7 +973,7 @@ struct TensorCaster struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -1021,7 +999,7 @@ struct TensorCaster struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -1098,7 +1076,7 @@ struct TensorCaster { #if !defined(DISABLE_FLOAT8_TYPES) template struct TensorCaster::value>::type> { + typename std::enable_if::value>::type> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); From 1d02ca1d527c4f6efabd7d88ff2823acf46a3b9c Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 9 Jun 2025 10:35:35 -0700 Subject: [PATCH 18/88] fix multiple partial specializations issue for MLFloat16 --- .../core/providers/cpu/tensor/cast_op.cc | 41 +++++++------------ 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 37e93dfdf65a0..0f05c44965eb7 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -389,17 +389,6 @@ struct ToInt4ElementConverter { } }; -template <> -struct ToInt4ElementConverter { - static int8_t ConvertToInt4(const MLFloat16& val) { - return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); - } - - static uint8_t ConvertToUInt4(const MLFloat16& val) { - return ToInt4ElementConverter::ConvertToUInt4(static_cast(val)); - } -}; - #if !defined(DISABLE_FLOAT8_TYPES) template <> @@ -840,27 +829,26 @@ struct TensorCaster -struct TensorCaster::value>::type> { +template <> +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); + const auto* in_data = in.Data(); auto* out_data = out.MutableData(); const size_t in_shape_size = narrow(shape.Size()); const size_t out_shape_size = narrow(out.Shape().Size()); ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output Int4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); + "The output Int4x2 tensor size is invalid for casting from BFloat16"); size_t i = 0; for (; i < in_shape_size - 1; i += 2) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - int8_t 
high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); out_data[i / 2] = Int4x2(low_val, high_val); } if (i < in_shape_size) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); out_data[i / 2] = Int4x2(low_val, 0); } } @@ -997,27 +985,26 @@ struct TensorCaster -struct TensorCaster::value>::type> { +template <> +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); + const auto* in_data = in.Data(); auto* out_data = out.MutableData(); const size_t in_shape_size = narrow(shape.Size()); const size_t out_shape_size = narrow(out.Shape().Size()); ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output UInt4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); + "The output UInt4x2 tensor size is invalid for casting from BFloat16"); size_t i = 0; for (; i < in_shape_size - 1; i += 2) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); out_data[i / 2] = UInt4x2(low_val, high_val); } if (i < in_shape_size) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); out_data[i / 2] = UInt4x2(low_val, 0); } } From 406666fc2c4d0fc3da07fc488a506bdd9e8b3bf1 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 9 Jun 2025 10:51:30 -0700 Subject: [PATCH 19/88] clean up std::string specializtions --- .../core/providers/cpu/tensor/cast_op.cc | 120 ------------------ 1 file changed, 120 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 0f05c44965eb7..e18b0c10af0ec 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -297,13 +297,6 @@ struct Int4ElementConverter { } }; -template <> -struct Int4ElementConverter { - static std::string Convert(int8_t val) { - return std::to_string(static_cast(val)); - } -}; - #if !defined(DISABLE_FLOAT8_TYPES) template <> @@ -437,29 +430,6 @@ struct ToInt4ElementConverter { #endif -template <> -struct ToInt4ElementConverter { - static int8_t ConvertToInt4(const std::string& val) { - int result; - try { - result = std::stoi(val); - } catch (...) { - result = 0; - } - return std::clamp(result, -8, 7); - } - - static uint8_t ConvertToUInt4(const std::string& val) { - unsigned int result; - try { - result = std::stoul(val); - } catch (...) 
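Editor's note on PATCH 18: when one type can match two different partial specializations and neither is more specialized than the other, the instantiation is ill-formed as ambiguous, whereas a full (explicit) specialization always takes precedence. That is presumably why the half-float partial specializations are replaced here with explicit BFloat16 specializations once MLFloat16 is already handled elsewhere. A compilable sketch of the general rule, with made-up trait conditions:

#include <type_traits>

template <typename T, typename Enable = void>
struct Caster { static constexpr int which = 0; };

// Matches every floating-point type.
template <typename T>
struct Caster<T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
  static constexpr int which = 1;
};

// If this were uncommented and the full specialization below removed,
// Caster<float> would match two partial specializations (float is both
// floating-point and 4 bytes wide), neither more specialized than the
// other, and the program would be ill-formed (ambiguous).
// template <typename T>
// struct Caster<T, typename std::enable_if<(sizeof(T) == 4)>::type> {
//   static constexpr int which = 2;
// };

// A full specialization always beats any partial specialization, so pinning
// the troublesome type explicitly is one way to resolve such a conflict.
template <>
struct Caster<float> { static constexpr int which = 3; };

static_assert(Caster<float>::which == 3, "full specialization wins");
static_assert(Caster<double>::which == 1, "partial specialization still applies");

int main() { return 0; }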
{ - result = 0; - } - return std::min(result, 15u); - } -}; - } // anonymous namespace // generic tensor X -> Y @@ -625,26 +595,6 @@ struct TensorCaster { } }; -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - template struct TensorCaster::value>::type> { @@ -757,26 +707,6 @@ struct TensorCaster { } }; -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - template struct TensorCaster::value>::type> { @@ -879,31 +809,6 @@ struct TensorCaster { } }; -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output Int4x2 tensor size is invalid for casting from string."); - - size_t i = 0; - for (; i < in_shape_size - 1; i += 2) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); - out_data[i / 2] = Int4x2(low_val, high_val); - } - - if (i < in_shape_size) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - out_data[i / 2] = Int4x2(low_val, 0); - } - } -}; - #if !defined(DISABLE_FLOAT8_TYPES) template struct TensorCaster { } }; -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output UInt4x2 tensor size is invalid for casting from string."); - - size_t i = 0; - for (; i < in_shape_size - 1; i += 2) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); - out_data[i / 2] = UInt4x2(low_val, high_val); - } - - if (i < in_shape_size) { - uint8_t 
low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - out_data[i / 2] = UInt4x2(low_val, 0); - } - } -}; - #if !defined(DISABLE_FLOAT8_TYPES) template struct TensorCaster Date: Mon, 9 Jun 2025 11:06:53 -0700 Subject: [PATCH 20/88] Add specializations for Int4x2 -> UInt4x2 and UInt4x2 -> Int4x2 --- .../core/providers/cpu/tensor/cast_op.cc | 63 +++++++++++++++++-- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index e18b0c10af0ec..137a85c79013b 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -500,6 +500,7 @@ struct TensorCaster(in_data[i].bits_); int8_t high_nibble = static_cast(packed) >> 4; int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); } @@ -519,8 +520,6 @@ struct TensorCaster(in_data[i].bits_); - - // Extract signed high and low nibble int8_t high_nibble = static_cast(packed) >> 4; int8_t low_nibble = static_cast(packed << 4) >> 4; @@ -546,6 +545,7 @@ struct TensorCaster(in_data[i].bits_); int8_t high_nibble = static_cast(packed) >> 4; int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); } @@ -568,6 +568,7 @@ struct TensorCaster(in_data[i].bits_); int8_t high_nibble = static_cast(packed) >> 4; int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); } @@ -589,12 +590,61 @@ struct TensorCaster { const uint8_t packed = static_cast(in_data[i].bits_); int8_t high_nibble = static_cast(packed) >> 4; int8_t low_nibble = static_cast(packed << 4) >> 4; + out_data[2 * i] = low_nibble != 0; out_data[2 * i + 1] = high_nibble != 0; } } }; +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(in_shape_size == out_shape_size, + "The output UInt4x2 tensor size doesn't match input Int4x2 tensor size."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + int8_t high_nibble = static_cast(packed) >> 4; + int8_t low_nibble = static_cast(packed << 4) >> 4; + + // Convert to unsigned by clamping at 0 + uint8_t high_unsigned = static_cast(std::max(0, static_cast(high_nibble)) & 0x0F); + uint8_t low_unsigned = static_cast(std::max(0, static_cast(low_nibble)) & 0x0F); + out_data[i] = UInt4x2(low_unsigned, high_unsigned); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(in_shape_size == out_shape_size, + "The output Int4x2 tensor size doesn't match input UInt4x2 tensor size."); + + for (size_t i = 0; i < in_shape_size; ++i) { + const uint8_t packed = static_cast(in_data[i].bits_); + uint8_t high_nibble = (packed 
>> 4) & 0x0F; + uint8_t low_nibble = packed & 0x0F; + + // Convert to signed by clamping to int4 range (-8 to 7) + int8_t high_signed = std::clamp(static_cast(high_nibble), static_cast(-8), static_cast(7)); + int8_t low_signed = std::clamp(static_cast(low_nibble), static_cast(-8), static_cast(7)); + out_data[i] = Int4x2(low_signed, high_signed); + } + } +}; + template struct TensorCaster::value>::type> { @@ -610,6 +660,7 @@ struct TensorCaster(in_data[i].bits_); uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); } @@ -631,8 +682,6 @@ struct TensorCaster(in_data[i].bits_); - - // Extract unsigned high and low nibble uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; @@ -658,6 +707,7 @@ struct TensorCaster(in_data[i].bits_); uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); } @@ -680,6 +730,7 @@ struct TensorCaster(in_data[i].bits_); uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); } @@ -688,7 +739,7 @@ struct TensorCaster -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -701,6 +752,7 @@ struct TensorCaster { const uint8_t packed = static_cast(in_data[i].bits_); uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; + out_data[2 * i] = low_nibble != 0; out_data[2 * i + 1] = high_nibble != 0; } @@ -837,7 +889,6 @@ struct TensorCaster struct TensorCaster::value>::type> { From bcb24127a02bd5877a07f449a73a6a91bc7c2ef3 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 9 Jun 2025 11:52:24 -0700 Subject: [PATCH 21/88] more concise Int4ElementConverter --- .../core/providers/cpu/tensor/cast_op.cc | 56 ++++--------------- 1 file changed, 10 insertions(+), 46 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 137a85c79013b..6721bb2918280 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -278,55 +278,19 @@ namespace { template struct Int4ElementConverter { static DstType Convert(int8_t val) { - // Default implementation for most numeric types - return static_cast(val); - } -}; - -template <> -struct Int4ElementConverter { - static MLFloat16 Convert(int8_t val) { - return MLFloat16(static_cast(val)); - } -}; - -template <> -struct Int4ElementConverter { - static BFloat16 Convert(int8_t val) { - return BFloat16(static_cast(val)); - } -}; - + if constexpr (IsOrtFloat16Type::value) { + return DstType(static_cast(val)); + } #if !defined(DISABLE_FLOAT8_TYPES) - -template <> -struct Int4ElementConverter { - static Float8E4M3FN Convert(int8_t val) { - return Float8E4M3FN(static_cast(val), true); - } -}; - -template <> -struct Int4ElementConverter { - static Float8E4M3FNUZ Convert(int8_t val) { - return Float8E4M3FNUZ(static_cast(val), true); - } -}; - -template <> -struct Int4ElementConverter { - static Float8E5M2 Convert(int8_t val) { - return Float8E5M2(static_cast(val), 
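Editor's note: the new Int4x2 <-> UInt4x2 casts just above saturate instead of wrapping: a negative signed nibble clamps to 0 on the unsigned side, and unsigned values 8..15 clamp to 7 on the signed side (equivalent, for 4-bit inputs, to the std::clamp calls in the patch). A small sketch of just that element-level saturation on plain integers, not the ORT packed types:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Signed 4-bit value (-8..7) -> unsigned 4-bit value (0..15): negatives saturate to 0.
inline uint8_t SignedToUnsignedNibble(int8_t v) {
  return static_cast<uint8_t>(std::max<int>(0, v) & 0x0F);
}

// Unsigned 4-bit value (0..15) -> signed 4-bit value (-8..7): 8..15 saturate to 7.
inline int8_t UnsignedToSignedNibble(uint8_t v) {
  return static_cast<int8_t>(std::min<int>(v, 7));
}

int main() {
  assert(SignedToUnsignedNibble(-3) == 0);  // negative clamps to 0
  assert(SignedToUnsignedNibble(5) == 5);
  assert(UnsignedToSignedNibble(15) == 7);  // above the signed range clamps to 7
  assert(UnsignedToSignedNibble(6) == 6);
  return 0;
}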
true); - } -}; - -template <> -struct Int4ElementConverter { - static Float8E5M2FNUZ Convert(int8_t val) { - return Float8E5M2FNUZ(static_cast(val), true); + else if constexpr (IsOrtFloat8Type::value) { + return DstType(static_cast(val), true); + } +#endif + else { + return static_cast(val); + } } }; -#endif // Helper struct for converting from any type to Int4/UInt4 elements template From cb205e5d6ce04134333785f849966d4eed69e7a7 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 9 Jun 2025 12:05:28 -0700 Subject: [PATCH 22/88] more concise ToInt4ElementConverter --- .../core/providers/cpu/tensor/cast_op.cc | 64 +++---------------- 1 file changed, 10 insertions(+), 54 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 6721bb2918280..7e4c206b8408c 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -293,19 +293,19 @@ struct Int4ElementConverter { }; // Helper struct for converting from any type to Int4/UInt4 elements -template +template struct ToInt4ElementConverter { // Default implementation for most numeric types static int8_t ConvertToInt4(const SrcType& val) { int8_t result = static_cast(val); // Clamp to int4 range (-8 to 7) - return std::clamp(result, static_cast(-8), static_cast(7)); + return std::clamp(result, int8_t(-8), int8_t(7)); } static uint8_t ConvertToUInt4(const SrcType& val) { uint8_t result = static_cast(val); // Clamp to uint4 range (0 to 15) - return std::min(result, static_cast(15)); + return std::min(result, uint8_t(15)); } }; @@ -335,65 +335,21 @@ struct ToInt4ElementConverter { } }; -template <> -struct ToInt4ElementConverter { - static int8_t ConvertToInt4(const BFloat16& val) { - return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); - } - - static uint8_t ConvertToUInt4(const BFloat16& val) { - return ToInt4ElementConverter::ConvertToUInt4(static_cast(val)); - } -}; - +template +struct ToInt4ElementConverter::value #if !defined(DISABLE_FLOAT8_TYPES) - -template <> -struct ToInt4ElementConverter { - static int8_t ConvertToInt4(const Float8E4M3FN& val) { - return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); - } - - static uint8_t ConvertToUInt4(const Float8E4M3FN& val) { - return ToInt4ElementConverter::ConvertToUInt4(static_cast(val)); - } -}; - -template <> -struct ToInt4ElementConverter { - static int8_t ConvertToInt4(const Float8E4M3FNUZ& val) { - return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); - } - - static uint8_t ConvertToUInt4(const Float8E4M3FNUZ& val) { - return ToInt4ElementConverter::ConvertToUInt4(static_cast(val)); - } -}; - -template <> -struct ToInt4ElementConverter { - static int8_t ConvertToInt4(const Float8E5M2& val) { - return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); - } - - static uint8_t ConvertToUInt4(const Float8E5M2& val) { - return ToInt4ElementConverter::ConvertToUInt4(static_cast(val)); - } -}; - -template <> -struct ToInt4ElementConverter { - static int8_t ConvertToInt4(const Float8E5M2FNUZ& val) { + || IsOrtFloat8Type::value +#endif + >> { + static int8_t ConvertToInt4(const SrcType& val) { return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); } - static uint8_t ConvertToUInt4(const Float8E5M2FNUZ& val) { + static uint8_t ConvertToUInt4(const SrcType& val) { return ToInt4ElementConverter::ConvertToUInt4(static_cast(val)); } }; -#endif - } // anonymous namespace // generic tensor X -> Y From 
1e8f6809ff87ce97350bb45bf4149d8fc7df59b4 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 9 Jun 2025 12:45:09 -0700 Subject: [PATCH 23/88] merge a few TensorCaster specializations --- .../core/providers/cpu/tensor/cast_op.cc | 140 ++++-------------- 1 file changed, 28 insertions(+), 112 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 7e4c206b8408c..140dfd0184a44 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -406,8 +406,12 @@ struct TensorCaster { }; template -struct TensorCaster::value>::type> { +struct TensorCaster::value + || IsOrtFloat16Type::value +#if !defined(DISABLE_FLOAT8_TYPES) + || IsOrtFloat8Type::value +#endif + >> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -443,61 +447,14 @@ struct TensorCaster(packed) >> 4; int8_t low_nibble = static_cast(packed << 4) >> 4; - // Low nibble first, then high nibble out_data[2 * i] = static_cast(low_nibble); out_data[2 * i + 1] = static_cast(high_nibble); } } }; -template -struct TensorCaster::value>::type> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; - -#if !defined(DISABLE_FLOAT8_TYPES) -template -struct TensorCaster::value>::type> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; - - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; -#endif - template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -541,33 +498,13 @@ struct TensorCaster { } }; -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(in_shape_size == out_shape_size, - "The output Int4x2 tensor size doesn't match input UInt4x2 tensor size."); - - 
for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; - - // Convert to signed by clamping to int4 range (-8 to 7) - int8_t high_signed = std::clamp(static_cast(high_nibble), static_cast(-8), static_cast(7)); - int8_t low_signed = std::clamp(static_cast(low_nibble), static_cast(-8), static_cast(7)); - out_data[i] = Int4x2(low_signed, high_signed); - } - } -}; - template -struct TensorCaster::value>::type> { +struct TensorCaster::value + || IsOrtFloat16Type::value +#if !defined(DISABLE_FLOAT8_TYPES) + || IsOrtFloat8Type::value +#endif + >> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -612,35 +549,12 @@ struct TensorCaster -struct TensorCaster::value>::type> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; - - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); - } - } -}; -#if !defined(DISABLE_FLOAT8_TYPES) -template -struct TensorCaster::value>::type> { +template <> +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); const size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); @@ -651,30 +565,32 @@ struct TensorCaster> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + out_data[2 * i] = low_nibble != 0; + out_data[2 * i + 1] = high_nibble != 0; } } }; -#endif template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + const size_t in_shape_size = narrow(shape.Size()); + const size_t out_shape_size = narrow(out.Shape().Size()); + ORT_ENFORCE(in_shape_size == out_shape_size, + "The output Int4x2 tensor size doesn't match input UInt4x2 tensor size."); for (size_t i = 0; i < in_shape_size; ++i) { const uint8_t packed = static_cast(in_data[i].bits_); uint8_t high_nibble = (packed >> 4) & 0x0F; uint8_t low_nibble = packed & 0x0F; - out_data[2 * i] = low_nibble != 0; - out_data[2 * i + 1] = high_nibble != 0; + // Convert to signed by clamping to int4 range (-8 to 7) + int8_t high_signed = std::clamp(static_cast(high_nibble), int8_t(-8), int8_t(7)); + int8_t low_signed = 
std::clamp(static_cast(low_nibble), int8_t(-8), int8_t(7)); + out_data[i] = Int4x2(low_signed, high_signed); } } }; From 2d1ca00bcc1123d72508fb14488ae176bfc33f18 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 9 Jun 2025 13:04:57 -0700 Subject: [PATCH 24/88] merge a few more TensorCaster specializations --- .../core/providers/cpu/tensor/cast_op.cc | 184 ++---------------- 1 file changed, 18 insertions(+), 166 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 140dfd0184a44..e1bb9df1f7692 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -511,7 +511,7 @@ struct TensorCaster(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting to ", typeid(DstType).name()); for (size_t i = 0; i < in_shape_size; ++i) { const uint8_t packed = static_cast(in_data[i].bits_); @@ -535,7 +535,7 @@ struct TensorCaster(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); ORT_ENFORCE(in_shape_size * 2 == shape_size, - "The UInt4x2 tensor size is invalid for casting to float."); + "The UInt4x2 tensor size is invalid for casting to float"); for (size_t i = 0; i < in_shape_size; ++i) { const uint8_t packed = static_cast(in_data[i].bits_); @@ -558,7 +558,7 @@ struct TensorCaster { const size_t shape_size = narrow(shape.Size()); const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting."); + ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting to bool"); for (size_t i = 0; i < in_shape_size; ++i) { const uint8_t packed = static_cast(in_data[i].bits_); @@ -596,34 +596,13 @@ struct TensorCaster { }; template -struct TensorCaster::value>::type> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output Int4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); - - size_t i = 0; - for (; i < in_shape_size - 1; i += 2) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); - out_data[i / 2] = Int4x2(low_val, high_val); - } - - if (i < in_shape_size) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - out_data[i / 2] = Int4x2(low_val, 0); - } - } -}; - -template -struct TensorCaster::value>::type> { +struct TensorCaster::value + || IsStandardFloatType::value + || std::is_same::value +#if !defined(DISABLE_FLOAT8_TYPES) + || IsOrtFloat8Type::value +#endif + >> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -648,32 +627,7 @@ struct TensorCaster -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t in_shape_size = 
narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output Int4x2 tensor size is invalid for casting from BFloat16"); - - size_t i = 0; - for (; i < in_shape_size - 1; i += 2) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); - out_data[i / 2] = Int4x2(low_val, high_val); - } - - if (i < in_shape_size) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - out_data[i / 2] = Int4x2(low_val, 0); - } - } -}; - -template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -697,37 +651,14 @@ struct TensorCaster { } }; -#if !defined(DISABLE_FLOAT8_TYPES) template -struct TensorCaster::value>::type> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output Int4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); - - size_t i = 0; - for (; i < in_shape_size - 1; i += 2) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); - out_data[i / 2] = Int4x2(low_val, high_val); - } - - if (i < in_shape_size) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - out_data[i / 2] = Int4x2(low_val, 0); - } - } -}; +struct TensorCaster::value + || IsStandardFloatType::value + || std::is_same::value +#if !defined(DISABLE_FLOAT8_TYPES) + || IsOrtFloat8Type::value #endif - -template -struct TensorCaster::value>::type> { + >> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -751,59 +682,8 @@ struct TensorCaster -struct TensorCaster::value>::type> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output UInt4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); - - size_t i = 0; - for (; i < in_shape_size - 1; i += 2) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); - out_data[i / 2] = UInt4x2(low_val, high_val); - } - - if (i < in_shape_size) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - out_data[i / 2] = UInt4x2(low_val, 0); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output UInt4x2 tensor size is invalid for casting from 
BFloat16"); - - size_t i = 0; - for (; i < in_shape_size - 1; i += 2) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); - out_data[i / 2] = UInt4x2(low_val, high_val); - } - - if (i < in_shape_size) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - out_data[i / 2] = UInt4x2(low_val, 0); - } - } -}; - template <> -struct TensorCaster { +struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -827,34 +707,6 @@ struct TensorCaster { } }; -#if !defined(DISABLE_FLOAT8_TYPES) -template -struct TensorCaster::value>::type> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output UInt4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); - - size_t i = 0; - for (; i < in_shape_size - 1; i += 2) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); - out_data[i / 2] = UInt4x2(low_val, high_val); - } - - if (i < in_shape_size) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - out_data[i / 2] = UInt4x2(low_val, 0); - } - } -}; -#endif - #if defined(_M_AMD64) && !defined(_M_ARM64EC) // specializations to use optimized and Windows x64-specific From 18b3b11ca0917725123ed0540dcd20378e2d6e78 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 9 Jun 2025 13:36:31 -0700 Subject: [PATCH 25/88] styling suggestions, lint --- .../core/providers/cpu/tensor/cast_op.cc | 65 +++++++++---------- 1 file changed, 31 insertions(+), 34 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index e1bb9df1f7692..412e34da69dda 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -63,21 +63,21 @@ using IsOrtFloat8Type = boost::mp11::mp_contains struct IsStandardIntegerType { static constexpr bool value = - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value; + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v; }; template struct IsStandardFloatType { static constexpr bool value = - std::is_same::value || - std::is_same::value; + std::is_same_v || + std::is_same_v; }; // string cast helpers @@ -336,11 +336,12 @@ struct ToInt4ElementConverter { }; template -struct ToInt4ElementConverter::value +struct ToInt4ElementConverter #if !defined(DISABLE_FLOAT8_TYPES) - || IsOrtFloat8Type::value + || IsOrtFloat8Type::value #endif - >> { + >> { static int8_t ConvertToInt4(const SrcType& val) { return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); } @@ -406,12 +407,12 @@ struct TensorCaster { }; template -struct TensorCaster::value - || IsOrtFloat16Type::value +struct TensorCaster::value || IsOrtFloat16Type::value #if !defined(DISABLE_FLOAT8_TYPES) - || 
IsOrtFloat8Type::value + || IsOrtFloat8Type::value #endif - >> { + >> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -499,12 +500,12 @@ struct TensorCaster { }; template -struct TensorCaster::value - || IsOrtFloat16Type::value +struct TensorCaster::value || IsOrtFloat16Type::value #if !defined(DISABLE_FLOAT8_TYPES) - || IsOrtFloat8Type::value + || IsOrtFloat8Type::value #endif - >> { + >> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -525,8 +526,7 @@ struct TensorCaster -struct TensorCaster::value>::type> { +struct TensorCaster::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -549,7 +549,6 @@ struct TensorCaster struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { @@ -596,13 +595,12 @@ struct TensorCaster { }; template -struct TensorCaster::value - || IsStandardFloatType::value - || std::is_same::value +struct TensorCaster::value || IsStandardFloatType::value || std::is_same_v #if !defined(DISABLE_FLOAT8_TYPES) - || IsOrtFloat8Type::value + || IsOrtFloat8Type::value #endif - >> { + >> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -652,13 +650,12 @@ struct TensorCaster { }; template -struct TensorCaster::value - || IsStandardFloatType::value - || std::is_same::value +struct TensorCaster::value || IsStandardFloatType::value || std::is_same_v #if !defined(DISABLE_FLOAT8_TYPES) - || IsOrtFloat8Type::value + || IsOrtFloat8Type::value #endif - >> { + >> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); From 8ceae1b625236f785ca142fecf37a42c1109bf4f Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 9 Jun 2025 15:27:30 -0700 Subject: [PATCH 26/88] update iterations over the input --- .../core/providers/cpu/tensor/cast_op.cc | 135 +++++------------- 1 file changed, 37 insertions(+), 98 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 412e34da69dda..a0381ef23d9e9 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -417,14 +417,9 @@ struct TensorCaster(); auto* out_data = out.MutableData(); - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; + for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + auto low_nibble = in_data[i].GetElem(0); + auto high_nibble = in_data[i].GetElem(1); out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); @@ -434,19 +429,14 @@ struct TensorCaster struct TensorCaster::value>::type> { + 
std::enable_if_t::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; + for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + auto low_nibble = in_data[i].GetElem(0); + auto high_nibble = in_data[i].GetElem(1); out_data[2 * i] = static_cast(low_nibble); out_data[2 * i + 1] = static_cast(high_nibble); @@ -460,14 +450,9 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The Int4x2 tensor size is invalid for casting."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; + for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + auto low_nibble = in_data[i].GetElem(0); + auto high_nibble = in_data[i].GetElem(1); out_data[2 * i] = low_nibble != 0; out_data[2 * i + 1] = high_nibble != 0; @@ -481,15 +466,9 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(in_shape_size == out_shape_size, - "The output UInt4x2 tensor size doesn't match input Int4x2 tensor size."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - int8_t high_nibble = static_cast(packed) >> 4; - int8_t low_nibble = static_cast(packed << 4) >> 4; + for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + auto low_nibble = in_data[i].GetElem(0); + auto high_nibble = in_data[i].GetElem(1); // Convert to unsigned by clamping at 0 uint8_t high_unsigned = static_cast(std::max(0, static_cast(high_nibble)) & 0x0F); @@ -510,14 +489,9 @@ struct TensorCaster(); auto* out_data = out.MutableData(); - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting to ", typeid(DstType).name()); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; + for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + auto low_nibble = in_data[i].GetElem(0); + auto high_nibble = in_data[i].GetElem(1); out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); @@ -531,18 +505,10 @@ struct TensorCaster(); auto* out_data = out.MutableData(); - // Confirm we can unpack the uint4 - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, - "The UInt4x2 tensor size is invalid for casting to float"); - - for (size_t i = 0; i < 
in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; + for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + auto low_nibble = in_data[i].GetElem(0); + auto high_nibble = in_data[i].GetElem(1); - // Low nibble first, then high nibble out_data[2 * i] = static_cast(low_nibble); out_data[2 * i + 1] = static_cast(high_nibble); } @@ -555,14 +521,9 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - const size_t shape_size = narrow(shape.Size()); - const size_t in_shape_size = narrow(in.Shape().Size()); - ORT_ENFORCE(in_shape_size * 2 == shape_size, "The UInt4x2 tensor size is invalid for casting to bool"); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; + for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + auto low_nibble = in_data[i].GetElem(0); + auto high_nibble = in_data[i].GetElem(1); out_data[2 * i] = low_nibble != 0; out_data[2 * i + 1] = high_nibble != 0; @@ -576,15 +537,9 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(in_shape_size == out_shape_size, - "The output Int4x2 tensor size doesn't match input UInt4x2 tensor size."); - - for (size_t i = 0; i < in_shape_size; ++i) { - const uint8_t packed = static_cast(in_data[i].bits_); - uint8_t high_nibble = (packed >> 4) & 0x0F; - uint8_t low_nibble = packed & 0x0F; + for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + auto low_nibble = in_data[i].GetElem(0); + auto high_nibble = in_data[i].GetElem(1); // Convert to signed by clamping to int4 range (-8 to 7) int8_t high_signed = std::clamp(static_cast(high_nibble), int8_t(-8), int8_t(7)); @@ -605,19 +560,15 @@ struct TensorCaster(); auto* out_data = out.MutableData(); - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output Int4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); - + const size_t shape_size = narrow(shape.Size()); size_t i = 0; - for (; i < in_shape_size - 1; i += 2) { + for (; i < shape_size - 1; i += 2) { int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); out_data[i / 2] = Int4x2(low_val, high_val); } - if (i < in_shape_size) { + if (i < shape_size) { int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); out_data[i / 2] = Int4x2(low_val, 0); } @@ -630,19 +581,15 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output Int4x2 tensor size is invalid for casting from bool."); - + const size_t shape_size = narrow(shape.Size()); size_t i = 0; - for (; i < in_shape_size - 1; i += 2) { + for (; i < shape_size - 1; i += 2) { int8_t low_val = in_data[i] ? 1 : 0; int8_t high_val = in_data[i + 1] ? 1 : 0; out_data[i / 2] = Int4x2(low_val, high_val); } - if (i < in_shape_size) { + if (i < shape_size) { int8_t low_val = in_data[i] ? 
1 : 0; out_data[i / 2] = Int4x2(low_val, 0); } @@ -660,19 +607,15 @@ struct TensorCaster(); auto* out_data = out.MutableData(); - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output UInt4x2 tensor size is invalid for casting from ", typeid(SrcType).name()); - + const size_t shape_size = narrow(shape.Size()); size_t i = 0; - for (; i < in_shape_size - 1; i += 2) { + for (; i < shape_size - 1; i += 2) { uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); out_data[i / 2] = UInt4x2(low_val, high_val); } - if (i < in_shape_size) { + if (i < shape_size) { uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); out_data[i / 2] = UInt4x2(low_val, 0); } @@ -685,19 +628,15 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - const size_t in_shape_size = narrow(shape.Size()); - const size_t out_shape_size = narrow(out.Shape().Size()); - ORT_ENFORCE(out_shape_size * 2 >= in_shape_size, - "The output UInt4x2 tensor size is invalid for casting from bool."); - + const size_t shape_size = narrow(shape.Size()); size_t i = 0; - for (; i < in_shape_size - 1; i += 2) { + for (; i < shape_size - 1; i += 2) { uint8_t low_val = in_data[i] ? 1 : 0; uint8_t high_val = in_data[i + 1] ? 1 : 0; out_data[i / 2] = UInt4x2(low_val, high_val); } - if (i < in_shape_size) { + if (i < shape_size) { uint8_t low_val = in_data[i] ? 1 : 0; out_data[i / 2] = UInt4x2(low_val, 0); } From bb806e43e93f7ad54949cc73f57844c47315dc98 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 9 Jun 2025 15:35:57 -0700 Subject: [PATCH 27/88] update iteration --- .../core/providers/cpu/tensor/cast_op.cc | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index a0381ef23d9e9..55fadf7aa807c 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -417,7 +417,7 @@ struct TensorCaster(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { auto low_nibble = in_data[i].GetElem(0); auto high_nibble = in_data[i].GetElem(1); @@ -434,7 +434,7 @@ struct TensorCaster(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { auto low_nibble = in_data[i].GetElem(0); auto high_nibble = in_data[i].GetElem(1); @@ -450,7 +450,7 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { auto low_nibble = in_data[i].GetElem(0); auto high_nibble = in_data[i].GetElem(1); @@ -466,7 +466,7 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { auto low_nibble = in_data[i].GetElem(0); auto high_nibble = in_data[i].GetElem(1); @@ -489,7 +489,7 @@ struct TensorCaster(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + for (size_t i = 0; i < narrow(shape.Size()) >> 1; 
++i) { auto low_nibble = in_data[i].GetElem(0); auto high_nibble = in_data[i].GetElem(1); @@ -505,7 +505,7 @@ struct TensorCaster(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { auto low_nibble = in_data[i].GetElem(0); auto high_nibble = in_data[i].GetElem(1); @@ -521,7 +521,7 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { auto low_nibble = in_data[i].GetElem(0); auto high_nibble = in_data[i].GetElem(1); @@ -537,7 +537,7 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()); i += 2) { + for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { auto low_nibble = in_data[i].GetElem(0); auto high_nibble = in_data[i].GetElem(1); @@ -565,12 +565,12 @@ struct TensorCaster::ConvertToInt4(in_data[i]); int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); - out_data[i / 2] = Int4x2(low_val, high_val); + out_data[i >> 1] = Int4x2(low_val, high_val); } if (i < shape_size) { int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); - out_data[i / 2] = Int4x2(low_val, 0); + out_data[i >> 1] = Int4x2(low_val, 0); } } }; @@ -586,12 +586,12 @@ struct TensorCaster { for (; i < shape_size - 1; i += 2) { int8_t low_val = in_data[i] ? 1 : 0; int8_t high_val = in_data[i + 1] ? 1 : 0; - out_data[i / 2] = Int4x2(low_val, high_val); + out_data[i >> 1] = Int4x2(low_val, high_val); } if (i < shape_size) { int8_t low_val = in_data[i] ? 1 : 0; - out_data[i / 2] = Int4x2(low_val, 0); + out_data[i >> 1] = Int4x2(low_val, 0); } } }; @@ -612,12 +612,12 @@ struct TensorCaster::ConvertToUInt4(in_data[i]); uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); - out_data[i / 2] = UInt4x2(low_val, high_val); + out_data[i >> 1] = UInt4x2(low_val, high_val); } if (i < shape_size) { uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); - out_data[i / 2] = UInt4x2(low_val, 0); + out_data[i >> 1] = UInt4x2(low_val, 0); } } }; @@ -633,12 +633,12 @@ struct TensorCaster { for (; i < shape_size - 1; i += 2) { uint8_t low_val = in_data[i] ? 1 : 0; uint8_t high_val = in_data[i + 1] ? 1 : 0; - out_data[i / 2] = UInt4x2(low_val, high_val); + out_data[i >> 1] = UInt4x2(low_val, high_val); } if (i < shape_size) { uint8_t low_val = in_data[i] ? 
1 : 0; - out_data[i / 2] = UInt4x2(low_val, 0); + out_data[i >> 1] = UInt4x2(low_val, 0); } } }; From 68259740cf0253777c62ef152f7b5e698d1268ad Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 9 Jun 2025 18:13:39 -0700 Subject: [PATCH 28/88] Update opset in unit tests to support int4 --- onnxruntime/test/providers/cpu/tensor/cast_op_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index fbb508bc2d034..0f8f7ba48f7b4 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -57,7 +57,7 @@ void TestCastOp(gsl::span input, const BaseTester::DimsVariant& dimensions, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& expected_failure_string = "", - int opset = 13, + int opset = 21, Saturate saturate = Saturate::None) { OpTester test("Cast", opset); test.AddAttribute("to", utils::ToTensorProtoElementType()); From b980762857c9a1b7e052c918408435ebd59f7a77 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Tue, 10 Jun 2025 11:36:53 -0700 Subject: [PATCH 29/88] Add unit tests --- .../test/providers/cpu/tensor/cast_op_test.cc | 655 +++++++++++++++++- 1 file changed, 653 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 0f8f7ba48f7b4..c05702b477f06 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -207,6 +207,249 @@ TEST(CastOpTest, ToString) { TestCastOp(gsl::make_span(int_16_input), gsl::make_span(int_string_data), shape); } +TEST(CastOpTest, Int4x2ToInt8) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + const std::vector expected_int8_output = {-8, 7, 0, -1, 3, -5, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_int8_output), shape); +} + +TEST(CastOpTest, Int4x2ToUInt8) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + // -8 becomes 248, -1 becomes 255, etc. 
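+ // (general rule, per C++ signed-to-unsigned conversion: a negative value v widens to v + 256 as uint8)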
+ const std::vector expected_uint8_output = {248, 7, 0, 255, 3, 251, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint8_output), shape); +} + +TEST(CastOpTest, Int4x2ToInt16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + const std::vector expected_int16_output = {-8, 7, 0, -1, 3, -5, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_int16_output), shape); +} + +TEST(CastOpTest, Int4x2ToUInt16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + // Negative values will be cast to their unsigned representation + const std::vector expected_uint16_output = {65528, 7, 0, 65535, 3, 65531, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint16_output), shape); +} + +TEST(CastOpTest, Int4x2ToInt32) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + const std::vector expected_int32_output = {-8, 7, 0, -1, 3, -5, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_int32_output), shape); +} + +TEST(CastOpTest, Int4x2ToUInt32) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + // Negative values will be cast to their unsigned representation + const std::vector expected_uint32_output = {4294967288, 7, 0, 4294967295, 3, 4294967291, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint32_output), shape); +} + +TEST(CastOpTest, Int4x2ToInt64) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + const std::vector expected_int64_output = {-8, 7, 0, -1, 3, -5, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_int64_output), shape); +} + +TEST(CastOpTest, UInt4x2ToUInt8) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_uint8_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint8_output), shape); +} + +TEST(CastOpTest, UInt4x2ToInt8) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_int8_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), 
gsl::make_span(expected_int8_output), shape); +} + +TEST(CastOpTest, UInt4x2ToUInt16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_uint16_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint16_output), shape); +} + +TEST(CastOpTest, UInt4x2ToInt16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_int16_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_int16_output), shape); +} + +TEST(CastOpTest, UInt4x2ToUInt32) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_uint32_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint32_output), shape); +} + +TEST(CastOpTest, UInt4x2ToInt32) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_int32_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_int32_output), shape); +} + +TEST(CastOpTest, UInt4x2ToUInt64) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_uint64_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint64_output), shape); +} + +TEST(CastOpTest, UInt4x2ToInt64) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_int64_output = {0, 15, 1, 14, 7, 8, 3, 12}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_int64_output), shape); +} + TEST(CastOpTest, Int4x2ToFloat) { // GIVEN const std::vector shape{2, 2, 2}; @@ -215,7 +458,7 @@ TEST(CastOpTest, Int4x2ToFloat) { Int4x2(-3, -4), Int4x2(5, -6), Int4x2(-8, 7)}; - // There will be twice as many unpacked elements + const std::vector expected_float_output = {1.0f, 2.0f, -3.0f, -4.0f, 5.0f, -6.0f, -8.0f, 7.0f}; // WHEN, THEN @@ -230,13 +473,302 @@ TEST(CastOpTest, UInt4x2ToFloat) { UInt4x2(2, 3), UInt4x2(7, 8), UInt4x2(14, 15)}; - // There will be twice as many unpacked elements + const std::vector expected_float_output = {0.0f, 1.0f, 2.0f, 3.0f, 7.0f, 8.0f, 14.0f, 15.0f}; // WHEN, THEN TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_float_output), shape); } +TEST(CastOpTest, Int4x2ToDouble) { + // GIVEN + const 
std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -3), // zero and negative + Int4x2(4, -2), // positive and negative + Int4x2(1, 6) // both positive + }; + + const std::vector expected_double_output = {-8.0, 7.0, 0.0, -3.0, 4.0, -2.0, 1.0, 6.0}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_double_output), shape); +} + +TEST(CastOpTest, UInt4x2ToDouble) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(1, 14), // small and large + UInt4x2(7, 8), // middle values + UInt4x2(3, 12) // mixed values + }; + + const std::vector expected_double_output = {0.0, 15.0, 1.0, 14.0, 7.0, 8.0, 3.0, 12.0}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_double_output), shape); +} + +TEST(CastOpTest, Int4x2ToMLFloat16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + const std::vector expected_float16_output = + CastedValues( + gsl::make_span( + std::vector{-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f})); + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float16_output), shape); +} + +TEST(CastOpTest, UInt4x2ToMLFloat16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(7, 8), + UInt4x2(3, 12)}; + + const std::vector expected_float16_output = + CastedValues( + gsl::make_span( + std::vector{0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f})); + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_float16_output), shape); +} + +TEST(CastOpTest, Int4x2ToBFloat16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + const std::vector expected_bfloat16_output = + CastedValues( + gsl::make_span( + std::vector{-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f})); + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_bfloat16_output), shape); +} + +TEST(CastOpTest, UInt4x2ToBFloat16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(7, 8), + UInt4x2(3, 12)}; + + const std::vector expected_bfloat16_output = + CastedValues( + gsl::make_span( + std::vector{0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f})); + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_bfloat16_output), shape); +} + +TEST(CastOpTest, Int8ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int8_input = {-10, 15, 0, -1, 3, -5, 6, 2}; + + // values outside int4 range get clamped + const std::vector expected_int4x2_output = { + Int4x2(-8, 7), // -10 clamped to -8, 15 clamped to 7 + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int8_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, UInt8ToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint8_input = {20, 15, 0, 1, 7, 25, 3, 12}; + + // values outside uint4 range get clamped + const std::vector expected_uint4x2_output = { + UInt4x2(15, 15), // 20 clamped to 15 + UInt4x2(0, 1), + UInt4x2(7, 15), // 25 clamped to 15 + UInt4x2(3, 12)}; + + // 
WHEN, THEN + TestCastOp(gsl::make_span(uint8_input), gsl::make_span(expected_uint4x2_output), shape); +} + +TEST(CastOpTest, Int16ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int16_input = {-10, 15, 0, -1, 3, -5, 6, 2}; + + // values outside int4 range get clamped + const std::vector expected_int4x2_output = { + Int4x2(-8, 7), // -10 clamped to -8, 15 clamped to 7 + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int16_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, UInt16ToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint16_input = {20, 15, 0, 1, 7, 25, 3, 12}; + + // values outside uint4 range get clamped + const std::vector expected_uint4x2_output = { + UInt4x2(15, 15), // 20 clamped to 15 + UInt4x2(0, 1), + UInt4x2(7, 15), // 25 clamped to 15 + UInt4x2(3, 12)}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint16_input), gsl::make_span(expected_uint4x2_output), shape); +} + +TEST(CastOpTest, Int32ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int32_input = {-10, 15, 0, -1, 3, -5, 6, 2}; + + // values outside int4 range get clamped + const std::vector expected_int4x2_output = { + Int4x2(-8, 7), // -10 clamped to -8, 15 clamped to 7 + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int32_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, Int32ToInt4x2OddNumberOfElements) { + // GIVEN + const std::vector odd_shape{5}; + const std::vector odd_input = {-8, 7, 0, -1, 3}; + + const std::vector expected_odd_output = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, 0) // last element paired with 0 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(odd_input), gsl::make_span(expected_odd_output), odd_shape); +} + +TEST(CastOpTest, Int32ToInt4x2EmptyTensor) { + // GIVEN + const std::vector empty_shape{0}; + const std::vector empty_input = {}; + const std::vector empty_output = {}; + + // WHEN, THEN + TestCastOp(gsl::make_span(empty_input), gsl::make_span(empty_output), empty_shape); +} + +TEST(CastOpTest, UInt32ToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint32_input = {20, 15, 0, 1, 7, 25, 3, 12}; + + // values outside uint4 range get clamped + const std::vector expected_uint4x2_output = { + UInt4x2(15, 15), // 20 clamped to 15 + UInt4x2(0, 1), + UInt4x2(7, 15), // 25 clamped to 15 + UInt4x2(3, 12)}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint32_input), gsl::make_span(expected_uint4x2_output), shape); +} + +TEST(CastOpTest, Int64ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int64_input = {-10, 15, 0, -1, 3, -5, 6, 2}; + + // values outside int4 range get clamped + const std::vector expected_int4x2_output = { + Int4x2(-8, 7), // -10 clamped to -8, 15 clamped to 7 + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int64_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, UInt64ToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint64_input = {20, 15, 0, 1, 7, 25, 3, 12}; + + // values outside uint4 range get clamped + const std::vector expected_uint4x2_output = { + UInt4x2(15, 15), // 20 clamped to 15 + UInt4x2(0, 1), + UInt4x2(7, 15), // 25 clamped to 15 + UInt4x2(3, 12)}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint64_input), gsl::make_span(expected_uint4x2_output), 
shape); +} + +TEST(CastOpTest, FloatToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector float_input = {-10.7f, 15.3f, 0.4f, -1.6f, 3.8f, -5.2f, 6.1f, 2.9f}; + + const std::vector expected_int4x2_output = { + Int4x2(-8, 7), // -10.7 rounded to -11, clamped to -8; 15.3 rounded to 15, clamped to 7 + Int4x2(0, -2), // 0.4 rounded to 0; -1.6 rounded to -2 + Int4x2(4, -5), // 3.8 rounded to 4; -5.2 rounded to -5 + Int4x2(6, 3) // 6.1 rounded to 6; 2.9 rounded to 3 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(float_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, DoubleToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector double_input = {20.7, 15.3, 0.4, 1.6, 7.8, 25.2, 3.1, 12.9}; + + const std::vector expected_uint4x2_output = { + UInt4x2(15, 15), // 20.7 rounded to 21, clamped to 15; 15.3 rounded to 15 + UInt4x2(0, 2), // 0.4 rounded to 0; 1.6 rounded to 2 + UInt4x2(8, 15), // 7.8 rounded to 8; 25.2 rounded to 25, clamped to 15 + UInt4x2(3, 13) // 3.1 rounded to 3; 12.9 rounded to 13 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(double_input), gsl::make_span(expected_uint4x2_output), shape); +} + #if !defined(DISABLE_FLOAT8_TYPES) template @@ -299,6 +831,125 @@ TEST(CastOpTest, ToFloat8E5M2FNUZ) { } } +TEST(CastOpTest, Int4x2ToFloat8E4M3FN) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2) + }; + + std::vector expected_float8_output; + expected_float8_output.reserve(8); + const std::vector float_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; + for (float val : float_values) { + expected_float8_output.emplace_back(Float8E4M3FN(val, true)); + } + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8_output), shape); +} + +TEST(CastOpTest, UInt4x2ToFloat8E4M3FN) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(7, 8), + UInt4x2(3, 12)}; + + std::vector expected_uint_float8_output; + expected_uint_float8_output.reserve(8); + const std::vector uint_float_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}; + for (float val : uint_float_values) { + expected_uint_float8_output.emplace_back(Float8E4M3FN(val, true)); + } + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8_output), shape); +} + +TEST(CastOpTest, Int4x2ToFloat8E5M2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + std::vector expected_float8e5m2_output; + expected_float8e5m2_output.reserve(8); + const std::vector float_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; + for (float val : float_values) { + expected_float8e5m2_output.emplace_back(Float8E5M2(val, true)); + } + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8e5m2_output), shape); +} + +TEST(CastOpTest, UInt4x2ToFloat8E5M2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(7, 8), + UInt4x2(3, 12)}; + + std::vector expected_uint_float8e5m2_output; + expected_uint_float8e5m2_output.reserve(8); + const std::vector uint_float_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}; + for (float val : uint_float_values) { + 
expected_uint_float8e5m2_output.emplace_back(Float8E5M2(val, true)); + } + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8e5m2_output), shape); +} + +TEST(CastOpTest, Float8E4M3FNToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + std::vector float8_input; + const std::vector input_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; + for (float val : input_values) { + float8_input.emplace_back(Float8E4M3FN(val, true)); + } + + const std::vector expected_int4x2_output = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + // WHEN, THEN + TestCastOp(gsl::make_span(float8_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, Float8E4M3FNToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + std::vector uint_float8_input; + const std::vector uint_input_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}; + for (float val : uint_input_values) { + uint_float8_input.emplace_back(Float8E4M3FN(val, true)); + } + + const std::vector expected_uint4x2_output = { + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(7, 8), + UInt4x2(3, 12)}; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint_float8_input), gsl::make_span(expected_uint4x2_output), shape); +} + #endif } // namespace test From 207e5df0eac6fb413a08b1725630b2017011735b Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Tue, 10 Jun 2025 14:13:27 -0700 Subject: [PATCH 30/88] add more unit tests --- .../test/providers/cpu/tensor/cast_op_test.cc | 113 +++++++++++++++++- 1 file changed, 110 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index c05702b477f06..5022a5683ccdc 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -450,6 +450,40 @@ TEST(CastOpTest, UInt4x2ToInt64) { TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_int64_output), shape); } +TEST(CastOpTest, Int4x2ToBool) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(0, -1), // zero and non-zero + Int4x2(7, 0), // non-zero and zero + Int4x2(-8, 3), // both non-zero + Int4x2(0, 0) // both zero + }; + + const bool bool_output[] = {false, true, true, false, true, true, false, false}; + const gsl::span expected_bool_output_span(bool_output); + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), expected_bool_output_span, shape); +} + +TEST(CastOpTest, UInt4x2ToBool) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 1), // zero and non-zero + UInt4x2(15, 0), // non-zero and zero + UInt4x2(8, 7), // both non-zero + UInt4x2(0, 0) // both zero + }; + + const bool bool_output[] = {false, true, true, false, true, true, false, false}; + const gsl::span expected_bool_output_span(bool_output); + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), expected_bool_output_span, shape); +} + TEST(CastOpTest, Int4x2ToFloat) { // GIVEN const std::vector shape{2, 2, 2}; @@ -584,6 +618,48 @@ TEST(CastOpTest, UInt4x2ToBFloat16) { TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_bfloat16_output), shape); } +TEST(CastOpTest, Int4x2ToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // negative values get clamped to 0 + Int4x2(0, -1), // -1 becomes 0 + Int4x2(3, -5), // -5 becomes 0 + Int4x2(6, 2) // positive 
values remain + }; + + const std::vector expected_uint4x2_output = { + UInt4x2(0, 7), // -8 clamped to 0 + UInt4x2(0, 0), // -1 clamped to 0 + UInt4x2(3, 0), // -5 clamped to 0 + UInt4x2(6, 2) // unchanged + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint4x2_output), shape); +} + +TEST(CastOpTest, UInt4x2ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // 15 is out of int4 range, should be clamped to 7 + UInt4x2(1, 14), // 14 is out of int4 range, should be clamped to 7 + UInt4x2(7, 8), // 8 is out of int4 range, should be clamped to 7 + UInt4x2(3, 6) // both within range + }; + + const std::vector expected_int4x2_output = { + Int4x2(0, 7), // 15 clamped to 7 + Int4x2(1, 7), // 14 clamped to 7 + Int4x2(7, 7), // 8 clamped to 7 + Int4x2(3, 6) // unchanged + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_int4x2_output), shape); +} + TEST(CastOpTest, Int8ToInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; @@ -641,7 +717,7 @@ TEST(CastOpTest, UInt16ToUInt4x2) { const std::vector expected_uint4x2_output = { UInt4x2(15, 15), // 20 clamped to 15 UInt4x2(0, 1), - UInt4x2(7, 15), // 25 clamped to 15 + UInt4x2(7, 15), // 25 clamped to 15 UInt4x2(3, 12)}; // WHEN, THEN @@ -769,6 +845,38 @@ TEST(CastOpTest, DoubleToUInt4x2) { TestCastOp(gsl::make_span(double_input), gsl::make_span(expected_uint4x2_output), shape); } +TEST(CastOpTest, BoolToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const bool bool_input[] = {false, true, true, false, false, true, true, true}; + const gsl::span bool_input_span(bool_input); + + const std::vector expected_int4x2_output = { + Int4x2(0, 1), + Int4x2(1, 0), + Int4x2(0, 1), + Int4x2(1, 1)}; + + // WHEN, THEN + TestCastOp(bool_input_span, gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, BoolToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const bool bool_input[] = {false, true, true, false, false, true, true, true}; + const gsl::span bool_input_span(bool_input); + + const std::vector expected_uint4x2_output = { + UInt4x2(0, 1), + UInt4x2(1, 0), + UInt4x2(0, 1), + UInt4x2(1, 1)}; + + // WHEN, THEN + TestCastOp(bool_input_span, gsl::make_span(expected_uint4x2_output), shape); +} + #if !defined(DISABLE_FLOAT8_TYPES) template @@ -838,8 +946,7 @@ TEST(CastOpTest, Int4x2ToFloat8E4M3FN) { Int4x2(-8, 7), Int4x2(0, -1), Int4x2(3, -5), - Int4x2(6, 2) - }; + Int4x2(6, 2)}; std::vector expected_float8_output; expected_float8_output.reserve(8); From 8ab01619b4795c6b6a47c122451ad52f68bb6376 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Tue, 10 Jun 2025 15:02:02 -0700 Subject: [PATCH 31/88] update string implementation and add tests --- .../core/providers/cpu/tensor/cast_op.cc | 168 ++++++++++-------- .../test/providers/cpu/tensor/cast_op_test.cc | 106 +++++++++++ 2 files changed, 199 insertions(+), 75 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 55fadf7aa807c..5118d11a3eae6 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -144,21 +144,6 @@ CastToString(const SrcType& input, std::string& output) { CastToString(static_cast(input), output); } -inline void CastToString(Int4x2 value, std::string& out) { - // Int4x2 contains two 4-bit signed integers - // Show both values as [first,second] - auto val0 = value.GetElem(0); // First 4-bit 
value - auto val1 = value.GetElem(1); // Second 4-bit value - out = "[" + std::to_string(static_cast(val0)) + "," + std::to_string(static_cast(val1)) + "]"; -} - -inline void CastToString(UInt4x2 value, std::string& out) { - // UInt4x2 contains two 4-bit unsigned integers - auto val0 = value.GetElem(0); // First 4-bit value - auto val1 = value.GetElem(1); // Second 4-bit value - out = "[" + std::to_string(static_cast(val0)) + "," + std::to_string(static_cast(val1)) + "]"; -} - template typename std::enable_if::value, void>::type CastFromString(const std::string& input, DstType& output) { @@ -183,66 +168,6 @@ CastFromString(const std::string& input, DstType& output) { output = gsl::narrow_cast(std::stoll(input)); } -inline void CastFromString(const std::string& in, Int4x2& out) { - // Parse string format: "[-3,7]" or "-3,7" or just "-3" (single value) - std::string trimmed = in; - - // Remove brackets if present - if (!trimmed.empty() && trimmed.front() == '[') { - trimmed = trimmed.substr(1); - } - if (!trimmed.empty() && trimmed.back() == ']') { - trimmed = trimmed.substr(0, trimmed.length() - 1); - } - - // Find comma separator - size_t comma_pos = trimmed.find(','); - int8_t val0 = 0, val1 = 0; - if (comma_pos != std::string::npos) { - // Two values: "val0,val1" - std::string val0_str = trimmed.substr(0, comma_pos); - std::string val1_str = trimmed.substr(comma_pos + 1); - - val0 = static_cast(std::clamp(std::stoi(val0_str), -8, 7)); - val1 = static_cast(std::clamp(std::stoi(val1_str), -8, 7)); - } else { - // Single value - use for both elements - val0 = val1 = static_cast(std::clamp(std::stoi(trimmed), -8, 7)); - } - - out = Int4x2(val0, val1); -} - -inline void CastFromString(const std::string& in, UInt4x2& out) { - // Parse string format: "[5,12]" or "5,12" or just "5" (single value) - std::string trimmed = in; - - // Remove brackets if present - if (!trimmed.empty() && trimmed.front() == '[') { - trimmed = trimmed.substr(1); - } - if (!trimmed.empty() && trimmed.back() == ']') { - trimmed = trimmed.substr(0, trimmed.length() - 1); - } - - // Find comma separator - size_t comma_pos = trimmed.find(','); - uint8_t val0 = 0, val1 = 0; - if (comma_pos != std::string::npos) { - // Two values: "val0,val1" - std::string val0_str = trimmed.substr(0, comma_pos); - std::string val1_str = trimmed.substr(comma_pos + 1); - - val0 = static_cast(std::clamp(std::stoi(val0_str), 0, 15)); - val1 = static_cast(std::clamp(std::stoi(val1_str), 0, 15)); - } else { - // Single value - use for both elements - val0 = val1 = static_cast(std::clamp(std::stoi(trimmed), 0, 15)); - } - - out = UInt4x2(val0, val1); -} - template #if !defined(DISABLE_FLOAT8_TYPES) typename std::enable_if::value || IsOrtFloat8Type::value, void>::type @@ -395,6 +320,99 @@ struct TensorCaster { } }; +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + // Unpack each Int4x2 into two separate string elements + size_t out_idx = 0; + for (size_t i = 0; i < narrow(shape.Size()) >> 1; i++) { + auto val0 = in_data[i].GetElem(0); + auto val1 = in_data[i].GetElem(1); + + out_data[out_idx++] = std::to_string(static_cast(val0)); + out_data[out_idx++] = std::to_string(static_cast(val1)); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data 
= out.MutableData(); + + // Unpack each UInt4x2 into two separate string elements + size_t out_idx = 0; + for (size_t i = 0; i < narrow(shape.Size()) >> 1; i++) { + auto val0 = in_data[i].GetElem(0); + auto val1 = in_data[i].GetElem(1); + + out_data[out_idx++] = std::to_string(static_cast(val0)); + out_data[out_idx++] = std::to_string(static_cast(val1)); + } + } +}; + +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + // Every 2 strings combine into 1 Int4x2 + const size_t shape_size = narrow(shape.Size()); + size_t i = 0; + for (; i < shape_size - 1; i += 2) { + // Parse each string and clamp to int4 range (-8 to 7) + int v0 = std::stoi(in_data[i]); + int v1 = std::stoi(in_data[i + 1]); + int8_t val0 = static_cast(std::clamp(v0, -8, 7)); + int8_t val1 = static_cast(std::clamp(v1, -8, 7)); + + out_data[i >> 1] = Int4x2(val0, val1); + } + + // Handle odd number of elements - pad with 0 + if (i < shape_size) { + int v0 = std::stoi(in_data[i]); + int8_t val0 = static_cast(std::clamp(v0, -8, 7)); + out_data[i >> 1] = Int4x2(val0, 0); + } + } +}; + +// TensorCaster specialization for string to UInt4x2 +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + // Every 2 strings combine into 1 UInt4x2 + const size_t shape_size = narrow(shape.Size()); + size_t i = 0; + for (; i < shape_size - 1; i += 2) { + // Parse each string and clamp to uint4 range (0 to 15) + int v0 = std::stoi(in_data[i]); + int v1 = std::stoi(in_data[i + 1]); + uint8_t val0 = static_cast(std::clamp(v0, 0, 15)); + uint8_t val1 = static_cast(std::clamp(v1, 0, 15)); + + out_data[i >> 1] = UInt4x2(val0, val1); + } + + // Handle odd number of elements - pad with 0 + if (i < shape_size) { + int v0 = std::stoi(in_data[i]); + uint8_t val0 = static_cast(std::clamp(v0, 0, 15)); + out_data[i >> 1] = UInt4x2(val0, 0); + } + } +}; + // tensor MLFloat16 -> float template <> struct TensorCaster { diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 5022a5683ccdc..44497fd795b4b 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -618,6 +618,50 @@ TEST(CastOpTest, UInt4x2ToBFloat16) { TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_bfloat16_output), shape); } +TEST(CastOpTest, Int4x2ToString) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // mixed values + Int4x2(6, 2) // positive values + }; + + // Each Int4x2 becomes two string values + const std::vector expected_output = { + "-8", "7", // from first Int4x2 + "0", "-1", // from second Int4x2 + "3", "-5", // from third Int4x2 + "6", "2" // from fourth Int4x2 + }; + + // WHEN, THEN + TestCastOp(gsl::span(int4x2_input), gsl::span(expected_output), shape); +} + +TEST(CastOpTest, UInt4x2ToString) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), // boundary values + UInt4x2(8, 7), // mid-range values + UInt4x2(3, 12), // mixed values + UInt4x2(10, 5) // other values + }; + + // Each UInt4x2 becomes two string values + const 
std::vector expected_output = { + "0", "15", // from first UInt4x2 + "8", "7", // from second UInt4x2 + "3", "12", // from third UInt4x2 + "10", "5" // from fourth UInt4x2 + }; + + // WHEN, THEN + TestCastOp(gsl::span(uint4x2_input), gsl::span(expected_output), shape); +} + TEST(CastOpTest, Int4x2ToUInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; @@ -877,6 +921,68 @@ TEST(CastOpTest, BoolToUInt4x2) { TestCastOp(bool_input_span, gsl::make_span(expected_uint4x2_output), shape); } +TEST(CastOpTest, StringToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector string_input = { + "-8", "7", // boundary values + "0", "-1", // zero and negative + "3", "-5", // mixed values + "6", "2" // positive values + }; + + const std::vector expected_output { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + // WHEN, THEN + TestCastOp(gsl::span(string_input), gsl::span(expected_output), shape); +} + +TEST(CastOpTest, StringToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector string_input = { + "0", "15", // boundary values + "8", "7", // mid-range values + "3", "12", // mixed values + "10", "5" // other values + }; + + const std::vector expected_output{ + UInt4x2(0, 15), + UInt4x2(8, 7), + UInt4x2(3, 12), + UInt4x2(10, 5)}; + + // WHEN, THEN + TestCastOp(gsl::span(string_input), gsl::span(expected_output), shape); +} + +TEST(CastOpTest, String2UInt4x2BoundaryValuesClamping) { + // GIVEN + // Test string values that need clamping to UInt4x2 range (0-15) + const std::vector shape{3, 2}; + const std::vector string_input = { + "-5", "20", // out of range values that should be clamped + "16", "100", // out of range values that should be clamped + "0", "15" // boundary values that are in range + }; + + // Each pair of strings becomes one UInt4x2 + // Values should be clamped to uint4 range (0-15) + const std::vector expected_output { + UInt4x2(0, 15), // -5 clamped to 0, 20 clamped to 15 + UInt4x2(15, 15), // 16 clamped to 15, 100 clamped to 15 + UInt4x2(0, 15) // 0 and 15 already in range + }; + + // WHEN, THEN + TestCastOp(gsl::span(string_input), gsl::span(expected_output), shape); +} + #if !defined(DISABLE_FLOAT8_TYPES) template From c81c502115a615c7d6572f6ce220210deac590f6 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Tue, 10 Jun 2025 15:07:20 -0700 Subject: [PATCH 32/88] lint --- onnxruntime/test/providers/cpu/tensor/cast_op_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 44497fd795b4b..96a6154867e9b 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -931,7 +931,7 @@ TEST(CastOpTest, StringToInt4x2) { "6", "2" // positive values }; - const std::vector expected_output { + const std::vector expected_output{ Int4x2(-8, 7), Int4x2(0, -1), Int4x2(3, -5), @@ -973,7 +973,7 @@ TEST(CastOpTest, String2UInt4x2BoundaryValuesClamping) { // Each pair of strings becomes one UInt4x2 // Values should be clamped to uint4 range (0-15) - const std::vector expected_output { + const std::vector expected_output{ UInt4x2(0, 15), // -5 clamped to 0, 20 clamped to 15 UInt4x2(15, 15), // 16 clamped to 15, 100 clamped to 15 UInt4x2(0, 15) // 0 and 15 already in range From 7e4ce012a4d4f3befb700f3efddcb64180733f14 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Tue, 10 Jun 2025 16:06:02 -0700 Subject: [PATCH 33/88] add MLFloat16 tests --- 
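Aside (not part of this patch): every cast in this series relies on the same packed-nibble layout -- element 0 is the
low nibble, element 1 the high nibble, and signed 4-bit values are recovered by sign extension. A minimal standalone
C++ sketch of that convention follows; the helper names pack_int4/unpack_int4 are hypothetical, used only for this
illustration, and are not part of the Int4x2/UInt4x2 API.

#include <cassert>
#include <cstdint>

// Pack two signed 4-bit values into one byte: lo -> bits 0..3, hi -> bits 4..7 (assumed layout, mirroring Int4x2(lo, hi)).
static uint8_t pack_int4(int8_t lo, int8_t hi) {
  return static_cast<uint8_t>((lo & 0x0F) | ((hi & 0x0F) << 4));
}

// Unpack element 0 (low nibble) or element 1 (high nibble), sign-extending back to int8_t.
static int8_t unpack_int4(uint8_t packed, int index) {
  const uint8_t nibble = (index == 0) ? (packed & 0x0F) : (packed >> 4);
  // Shift the nibble into the top 4 bits of a signed byte, then arithmetic-shift back down.
  return static_cast<int8_t>(static_cast<int8_t>(nibble << 4) >> 4);
}

int main() {
  const uint8_t packed = pack_int4(-3, 7);
  assert(unpack_int4(packed, 0) == -3);  // element 0: low nibble
  assert(unpack_int4(packed, 1) == 7);   // element 1: high nibble
  return 0;
}
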
.../test/providers/cpu/tensor/cast_op_test.cc | 109 ++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 96a6154867e9b..0c3c6ccfebbde 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -889,6 +889,115 @@ TEST(CastOpTest, DoubleToUInt4x2) { TestCastOp(gsl::make_span(double_input), gsl::make_span(expected_uint4x2_output), shape); } +TEST(CastOpTest, MLFloat16ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const MLFloat16 mlfloat16_array[8] = { + MLFloat16(static_cast(-8)), + MLFloat16(static_cast(7)), + MLFloat16(static_cast(0)), + MLFloat16(static_cast(-1)), + MLFloat16(static_cast(3)), + MLFloat16(static_cast(-5)), + MLFloat16(static_cast(6)), + MLFloat16(static_cast(2))}; + + const std::vector expected_int4x2 = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + // WHEN, THEN + TestCastOp( + gsl::span(mlfloat16_array, 8), + gsl::span(expected_int4x2), + shape); +} + +TEST(CastOpTest, MLFloat16ToUInt4x2) { + // GIVEN + // 8 MLFloat16 values will compress to 4 UInt4x2 values + const std::vector shape{2, 4}; // Shape that contains 8 elements + + // MLFloat16 values: 0, 15, 8, 7, 3, 12, 10, 5 + const MLFloat16 mlfloat16_array[8] = { + MLFloat16(static_cast(0)), + MLFloat16(static_cast(15)), + MLFloat16(static_cast(8)), + MLFloat16(static_cast(7)), + MLFloat16(static_cast(3)), + MLFloat16(static_cast(12)), + MLFloat16(static_cast(10)), + MLFloat16(static_cast(5))}; + + const std::vector expected_uint4x2 = { + UInt4x2(0, 15), + UInt4x2(8, 7), + UInt4x2(3, 12), + UInt4x2(10, 5)}; + + // WHEN, THEN + TestCastOp( + gsl::span(mlfloat16_array, 8), + gsl::span(expected_uint4x2), + shape); +} + +TEST(CastOpTest, MLFloat16ToInt4x2BoundaryValuesClamping) { + // GIVEN + // Test MLFloat16 values that need clamping to Int4x2 range (-8 to 7) + const std::vector shape{3, 2}; + const MLFloat16 mlfloat16_array[6] = { + MLFloat16(static_cast(-10)), // Below min, should clamp to -8 + MLFloat16(static_cast(9)), // Above max, should clamp to 7 + MLFloat16(static_cast(-8)), // At min, should remain -8 + MLFloat16(static_cast(7)), // At max, should remain 7 + MLFloat16(static_cast(-0.6f)), // Should round to -1 + MLFloat16(static_cast(1.7f)) // Should round to 2 + }; + + // Values should be clamped to int4 range (-8 to 7) + const std::vector expected_int4x2 = { + Int4x2(-8, 7), // -10 clamped to -8, 9 clamped to 7 + Int4x2(-8, 7), // -8 and 7 already at boundaries + Int4x2(-1, 2) // -0.6 rounds to -1, 1.7 rounds to 2 + }; + + // WHEN, THEN + TestCastOp( + gsl::span(mlfloat16_array, 6), + gsl::span(expected_int4x2), + shape); +} + +TEST(CastOpTest, MLFloat16ToUInt4x2BoundaryValuesClamping) { + // GIVEN + // Test MLFloat16 values that need clamping to UInt4x2 range (0 to 15) + const std::vector shape{3, 2}; // Shape that contains 6 elements + const MLFloat16 mlfloat16_array[6] = { + MLFloat16(static_cast(-5)), // Negative, should clamp to 0 + MLFloat16(static_cast(20)), // Above max, should clamp to 15 + MLFloat16(static_cast(0)), // At min, should remain 0 + MLFloat16(static_cast(15)), // At max, should remain 15 + MLFloat16(static_cast(3.4f)), // Should round to 3 + MLFloat16(static_cast(5.7f)) // Should round to 6 + }; + + // Values should be clamped to uint4 range (0-15) + const std::vector expected_uint4x2 = { + UInt4x2(0, 15), // -5 clamped to 0, 20 clamped to 15 + 
UInt4x2(0, 15), // 0 and 15 already at boundaries + UInt4x2(3, 6) // 3.4 rounds to 3, 5.7 rounds to 6 + }; + + // WHEN, THEN + TestCastOp( + gsl::span(mlfloat16_array, 6), + gsl::span(expected_uint4x2), + shape); +} + TEST(CastOpTest, BoolToInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; From 6ec215c80a40f72dd67471e00d6028f395852a01 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Thu, 12 Jun 2025 07:52:42 -0700 Subject: [PATCH 34/88] update iteration, add test for odd elements --- .../core/providers/cpu/tensor/cast_op.cc | 92 ++++++++++--------- .../test/providers/cpu/tensor/cast_op_test.cc | 15 +++ 2 files changed, 63 insertions(+), 44 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 5118d11a3eae6..7e65965611a69 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -328,12 +328,11 @@ struct TensorCaster { // Unpack each Int4x2 into two separate string elements size_t out_idx = 0; - for (size_t i = 0; i < narrow(shape.Size()) >> 1; i++) { - auto val0 = in_data[i].GetElem(0); - auto val1 = in_data[i].GetElem(1); + for (size_t i = 0; i < narrow(shape.Size()); ++i) { + // elem 0 is the low nibble, 1 the high nibble + auto val = in_data[i >> 1].GetElem(i & 0x1); - out_data[out_idx++] = std::to_string(static_cast(val0)); - out_data[out_idx++] = std::to_string(static_cast(val1)); + out_data[out_idx++] = std::to_string(static_cast(val)); } } }; @@ -346,12 +345,11 @@ struct TensorCaster { // Unpack each UInt4x2 into two separate string elements size_t out_idx = 0; - for (size_t i = 0; i < narrow(shape.Size()) >> 1; i++) { - auto val0 = in_data[i].GetElem(0); - auto val1 = in_data[i].GetElem(1); + for (size_t i = 0; i < narrow(shape.Size()); ++i) { + // elem 0 is the low nibble, 1 the high nibble + auto val = in_data[i >> 1].GetElem(i & 0x1); - out_data[out_idx++] = std::to_string(static_cast(val0)); - out_data[out_idx++] = std::to_string(static_cast(val1)); + out_data[out_idx++] = std::to_string(static_cast(val)); } } }; @@ -379,6 +377,7 @@ struct TensorCaster { if (i < shape_size) { int v0 = std::stoi(in_data[i]); int8_t val0 = static_cast(std::clamp(v0, -8, 7)); + out_data[i >> 1] = Int4x2(val0, 0); } } @@ -408,6 +407,7 @@ struct TensorCaster { if (i < shape_size) { int v0 = std::stoi(in_data[i]); uint8_t val0 = static_cast(std::clamp(v0, 0, 15)); + out_data[i >> 1] = UInt4x2(val0, 0); } } @@ -435,12 +435,11 @@ struct TensorCaster(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { - auto low_nibble = in_data[i].GetElem(0); - auto high_nibble = in_data[i].GetElem(1); + for (size_t i = 0; i < narrow(shape.Size()); ++i) { + // elem 0 is the low nibble, 1 the high nibble + auto val = in_data[i >> 1].GetElem(i & 0x1); - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + out_data[i] = Int4ElementConverter::Convert(val); } } }; @@ -452,12 +451,11 @@ struct TensorCaster(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { - auto low_nibble = in_data[i].GetElem(0); - auto high_nibble = in_data[i].GetElem(1); + for (size_t i = 0; i < narrow(shape.Size()); ++i) { + // elem 0 is the low nibble, 1 the high nibble + auto val = in_data[i >> 1].GetElem(i & 0x1); - out_data[2 * i] = static_cast(low_nibble); - out_data[2 * i + 1] = static_cast(high_nibble); + out_data[i] = static_cast(val); } 
} }; @@ -468,12 +466,11 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { - auto low_nibble = in_data[i].GetElem(0); - auto high_nibble = in_data[i].GetElem(1); + for (size_t i = 0; i < narrow(shape.Size()); ++i) { + // elem 0 is the low nibble, 1 the high nibble + auto val = in_data[i >> 1].GetElem(i & 0x1); - out_data[2 * i] = low_nibble != 0; - out_data[2 * i + 1] = high_nibble != 0; + out_data[i] = val != 0; } } }; @@ -484,13 +481,14 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { + for (size_t i = 0; i < narrow(shape.Size() + 1) >> 1; ++i) { auto low_nibble = in_data[i].GetElem(0); auto high_nibble = in_data[i].GetElem(1); // Convert to unsigned by clamping at 0 - uint8_t high_unsigned = static_cast(std::max(0, static_cast(high_nibble)) & 0x0F); uint8_t low_unsigned = static_cast(std::max(0, static_cast(low_nibble)) & 0x0F); + uint8_t high_unsigned = static_cast(std::max(0, static_cast(high_nibble)) & 0x0F); + out_data[i] = UInt4x2(low_unsigned, high_unsigned); } } @@ -507,12 +505,11 @@ struct TensorCaster(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { - auto low_nibble = in_data[i].GetElem(0); - auto high_nibble = in_data[i].GetElem(1); + for (size_t i = 0; i < narrow(shape.Size()); ++i) { + // elem 0 is the low nibble, 1 the high nibble + auto val = in_data[i >> 1].GetElem(i & 0x1); - out_data[2 * i] = Int4ElementConverter::Convert(low_nibble); - out_data[2 * i + 1] = Int4ElementConverter::Convert(high_nibble); + out_data[i] = Int4ElementConverter::Convert(val); } } }; @@ -523,12 +520,11 @@ struct TensorCaster(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { - auto low_nibble = in_data[i].GetElem(0); - auto high_nibble = in_data[i].GetElem(1); + for (size_t i = 0; i < narrow(shape.Size()); ++i) { + // elem 0 is the low nibble, 1 the high nibble + auto val = in_data[i >> 1].GetElem(i & 0x1); - out_data[2 * i] = static_cast(low_nibble); - out_data[2 * i + 1] = static_cast(high_nibble); + out_data[i] = static_cast(val); } } }; @@ -539,12 +535,11 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { - auto low_nibble = in_data[i].GetElem(0); - auto high_nibble = in_data[i].GetElem(1); + for (size_t i = 0; i < narrow(shape.Size()); ++i) { + // elem 0 is the low nibble, 1 the high nibble + auto val = in_data[i >> 1].GetElem(i & 0x1); - out_data[2 * i] = low_nibble != 0; - out_data[2 * i + 1] = high_nibble != 0; + out_data[i] = val != 0; } } }; @@ -555,13 +550,14 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()) >> 1; ++i) { + for (size_t i = 0; i < narrow(shape.Size() + 1) >> 1; ++i) { auto low_nibble = in_data[i].GetElem(0); auto high_nibble = in_data[i].GetElem(1); // Convert to signed by clamping to int4 range (-8 to 7) - int8_t high_signed = std::clamp(static_cast(high_nibble), int8_t(-8), int8_t(7)); int8_t low_signed = std::clamp(static_cast(low_nibble), int8_t(-8), int8_t(7)); + int8_t high_signed = std::clamp(static_cast(high_nibble), int8_t(-8), int8_t(7)); + out_data[i] = Int4x2(low_signed, high_signed); } } @@ -583,11 +579,13 @@ struct TensorCaster::ConvertToInt4(in_data[i]); int8_t 
high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + out_data[i >> 1] = Int4x2(low_val, high_val); } if (i < shape_size) { int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + out_data[i >> 1] = Int4x2(low_val, 0); } } @@ -604,11 +602,13 @@ struct TensorCaster { for (; i < shape_size - 1; i += 2) { int8_t low_val = in_data[i] ? 1 : 0; int8_t high_val = in_data[i + 1] ? 1 : 0; + out_data[i >> 1] = Int4x2(low_val, high_val); } if (i < shape_size) { int8_t low_val = in_data[i] ? 1 : 0; + out_data[i >> 1] = Int4x2(low_val, 0); } } @@ -630,11 +630,13 @@ struct TensorCaster::ConvertToUInt4(in_data[i]); uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + out_data[i >> 1] = UInt4x2(low_val, high_val); } if (i < shape_size) { uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + out_data[i >> 1] = UInt4x2(low_val, 0); } } @@ -651,11 +653,13 @@ struct TensorCaster { for (; i < shape_size - 1; i += 2) { uint8_t low_val = in_data[i] ? 1 : 0; uint8_t high_val = in_data[i + 1] ? 1 : 0; + out_data[i >> 1] = UInt4x2(low_val, high_val); } if (i < shape_size) { uint8_t low_val = in_data[i] ? 1 : 0; + out_data[i >> 1] = UInt4x2(low_val, 0); } } diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 0c3c6ccfebbde..f74d3e0cbfc51 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -306,6 +306,21 @@ TEST(CastOpTest, Int4x2ToUInt32) { TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint32_output), shape); } +TEST(CastOpTest, Int4x2ToInt32OddNumberOfElements) { + // GIVEN + const std::vector odd_shape{5}; + const std::vector odd_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, 0), + }; + + const std::vector expected_odd_output = {-8, 7, 0, -1, 3}; + + // WHEN, THEN + TestCastOp(gsl::make_span(odd_input), gsl::make_span(expected_odd_output), odd_shape); +} + TEST(CastOpTest, Int4x2ToInt64) { // GIVEN const std::vector shape{2, 2, 2}; From b39a4f43164a34f209a52a79a0d54696fc6dcba6 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Thu, 12 Jun 2025 08:51:57 -0700 Subject: [PATCH 35/88] add specialization from float to MLFloat16 --- onnxruntime/core/providers/cpu/tensor/cast_op.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 7e65965611a69..9dd0f4f9098ab 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -424,6 +424,17 @@ struct TensorCaster { } }; +// tensor float -> MLFloat16 +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + auto in_data = in.Data(); + auto out_data = out.MutableData(); + const size_t shape_size = narrow(shape.Size()); + MlasConvertFloatToHalfBuffer(in_data, out_data, shape_size); + } +}; + template struct TensorCaster::value || IsOrtFloat16Type::value From 351626ef1957a60cf6572aaa1369db3ade2f5d55 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Thu, 12 Jun 2025 10:22:20 -0700 Subject: [PATCH 36/88] try to fix pipeline errors --- onnxruntime/core/providers/cpu/tensor/cast_op.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc 
b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 9dd0f4f9098ab..c4ac3a35bee5d 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -262,7 +262,7 @@ struct ToInt4ElementConverter { template struct ToInt4ElementConverter + std::enable_if_t::value #if !defined(DISABLE_FLOAT8_TYPES) || IsOrtFloat8Type::value #endif @@ -576,7 +576,7 @@ struct TensorCaster { template struct TensorCaster::value || IsStandardFloatType::value || std::is_same_v + std::enable_if_t::value || IsStandardFloatType::value || IsOrtFloat16Type::value #if !defined(DISABLE_FLOAT8_TYPES) || IsOrtFloat8Type::value #endif @@ -627,7 +627,7 @@ struct TensorCaster { template struct TensorCaster::value || IsStandardFloatType::value || std::is_same_v + std::enable_if_t::value || IsStandardFloatType::value || IsOrtFloat16Type::value #if !defined(DISABLE_FLOAT8_TYPES) || IsOrtFloat8Type::value #endif @@ -698,7 +698,8 @@ void CastMLFloat16ThroughFloatTensor( // tensor MLFloat16 -> X template -struct TensorCaster { +struct TensorCaster && !std::is_same_v>> { void Cast(const OpKernelContext& context, const TensorShape& shape, const Tensor& in, Tensor& out) const { CastMLFloat16ThroughFloatTensor(context, shape, in, out); } From 3b84c142e890ac839085b9fa194270c3e26c511f Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Thu, 12 Jun 2025 10:27:21 -0700 Subject: [PATCH 37/88] lint --- onnxruntime/core/providers/cpu/tensor/cast_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index c4ac3a35bee5d..2dedaa4fb33d8 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -699,7 +699,7 @@ void CastMLFloat16ThroughFloatTensor( // tensor MLFloat16 -> X template struct TensorCaster && !std::is_same_v>> { + std::enable_if_t && !std::is_same_v>> { void Cast(const OpKernelContext& context, const TensorShape& shape, const Tensor& in, Tensor& out) const { CastMLFloat16ThroughFloatTensor(context, shape, in, out); } From b5ecbf9ae49377b264971c7766f77c3335c73c61 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Thu, 12 Jun 2025 12:18:49 -0700 Subject: [PATCH 38/88] Try [[noreturn]] to fix pipelines --- include/onnxruntime/core/framework/data_types_internal.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index 05f4c10995ef2..92506828c38de 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -338,7 +338,7 @@ class CallableDispatchableHelper { // Other policies may set the second result argument accordingly. 
template struct UnsupportedTypeDefaultPolicy { - void operator()(int32_t dt_type, Ret& /*result*/) const { + [[noreturn]] void operator()(int32_t dt_type, Ret& /*result*/) const { ORT_THROW("Unsupported data type: ", dt_type); } }; From b19df98acf437113f12a7bfdde30f6190a190855 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Thu, 12 Jun 2025 12:46:11 -0700 Subject: [PATCH 39/88] suppress 4702 warning --- include/onnxruntime/core/framework/data_types_internal.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index 92506828c38de..7c8c3698ac32b 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -338,8 +338,15 @@ class CallableDispatchableHelper { // Other policies may set the second result argument accordingly. template struct UnsupportedTypeDefaultPolicy { - [[noreturn]] void operator()(int32_t dt_type, Ret& /*result*/) const { + void operator()(int32_t dt_type, Ret& /*result*/) const { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4702) +#endif ORT_THROW("Unsupported data type: ", dt_type); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif } }; From 77b1916bda01150d3cef2b570b45797a2f4e5bed Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Fri, 13 Jun 2025 13:27:05 -0700 Subject: [PATCH 40/88] try pipeline fix --- include/onnxruntime/core/framework/data_types_internal.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index 7c8c3698ac32b..e4e7799019280 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -338,7 +338,7 @@ class CallableDispatchableHelper { // Other policies may set the second result argument accordingly.
template struct UnsupportedTypeDefaultPolicy { -[[noreturn]] void operator()(int32_t dt_type, Ret& /*result*/) const { + [[noreturn]] void operator()(int32_t dt_type, Ret& /*result*/) const { #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4702) From 3d9be2fe68b58052fe77cd525e77969afe9ea1e7 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 16 Jun 2025 19:38:21 -0700 Subject: [PATCH 42/88] disable warning --- .../core/framework/data_types_internal.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index bd1c54657e33f..fa5804c3b924b 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -324,7 +324,14 @@ class CallableDispatchableHelper { int Invoke(Fn&& fn, Args&&... args) { if (utils::ToTensorProtoElementType() == dt_type_) { std::forward(fn)(std::forward(args)...); +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4702) +#endif ++called_; +#if defined(_MSC_VER) +#pragma warning(pop) +#endif } return 0; } @@ -339,14 +346,7 @@ class CallableDispatchableHelper { template struct UnsupportedTypeDefaultPolicy { [[noreturn]] void operator()(int32_t dt_type, Ret& /*result*/) const { -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4702) -#endif ORT_THROW("Unsupported data type: ", dt_type); -#if defined(_MSC_VER) -#pragma warning(pop) -#endif } }; From 9ce545aa98e8b710678cae7593411bb03b7c0abf Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 16 Jun 2025 19:57:26 -0700 Subject: [PATCH 43/88] move pragma statements --- include/onnxruntime/core/framework/data_types_internal.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index fa5804c3b924b..b166e272360be 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -323,11 +323,11 @@ class CallableDispatchableHelper { template int Invoke(Fn&& fn, Args&&... args) { if (utils::ToTensorProtoElementType() == dt_type_) { - std::forward(fn)(std::forward(args)...); #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4702) #endif + std::forward(fn)(std::forward(args)...); ++called_; #if defined(_MSC_VER) #pragma warning(pop) From 082ae757abbf769d909b0c22440a9892795a6647 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 16 Jun 2025 20:58:05 -0700 Subject: [PATCH 44/88] update pragma --- .../core/framework/data_types_internal.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index b166e272360be..4cc57ba4b5391 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -319,22 +319,22 @@ class CallableDispatchableHelper { public: explicit CallableDispatchableHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0) {} - // Must return integer to be in a expandable context - template - int Invoke(Fn&& fn, Args&&... 
args) { - if (utils::ToTensorProtoElementType() == dt_type_) { #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4702) #endif + // Must return integer to be in a expandable context + template + int Invoke(Fn&& fn, Args&&... args) { + if (utils::ToTensorProtoElementType() == dt_type_) { std::forward(fn)(std::forward(args)...); ++called_; -#if defined(_MSC_VER) -#pragma warning(pop) -#endif } return 0; } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif void CheckCalledOnce() const { ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_); From c70e7f3a5f45c88c121d366d0456146ab0d69004 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Wed, 25 Jun 2025 11:47:35 -0700 Subject: [PATCH 45/88] Update docs for Cast --- docs/OperatorKernels.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 5154c334acc23..6884d72522884 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -58,11 +58,11 @@ Do not modify directly.* |BitwiseOr|*in* A:**T**
*in* B:**T**<br/> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |BitwiseXor|*in* A:**T**<br/> *in* B:**T**<br/> *out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |BlackmanWindow|*in* size:**T1**<br/> *out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)<br/> **T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Cast|*in* input:**T1**<br/> *out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Cast|*in* input:**T1**<br/> *out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
+|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
 |Ceil|*in* X:**T**<br/> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
 |||[6, 12]|**T** = tensor(double), tensor(float)|
 |Celu|*in* X:**T**
*out* Y:**T**|12+|**T** = tensor(float)| From 950a8e32e54fadcb99e7685235febb3885a42e76 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Wed, 25 Jun 2025 16:49:31 -0700 Subject: [PATCH 46/88] Update onnx patch --- cmake/patches/onnx/onnx.patch | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index 30d5a44a1d1cc..fa7f0f64919bd 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -200,3 +200,30 @@ index 0aab3e26..27f32195 100644 + + #endif // ! ONNX_ONNX_PB_H +diff --git a/onnx/backend/test/case/node/cast.py b/onnx/backend/test/case/node/cast.py +index 9696373920d72e782948cf9b5cf137983d229814..226f57eee4701dc8fefbf5d64983385b077bc80e 100644 +--- a/onnx/backend/test/case/node/cast.py ++++ b/onnx/backend/test/case/node/cast.py +@@ -59,10 +59,11 @@ class Cast(Base): + test_cases = [ +- ("FLOAT", "UINT4"), +- ("FLOAT16", "UINT4"), +- ("FLOAT", "INT4"), +- ("FLOAT16", "INT4"), +- ("UINT4", "FLOAT"), +- ("UINT4", "FLOAT16"), +- ("UINT4", "UINT8"), +- ("INT4", "FLOAT"), +- ("INT4", "FLOAT16"), +- ("INT4", "INT8"), ++ # Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074 ++ # ("FLOAT", "UINT4"), ++ # ("FLOAT16", "UINT4"), ++ # ("FLOAT", "INT4"), ++ # ("FLOAT16", "INT4"), ++ # ("UINT4", "FLOAT"), ++ # ("UINT4", "FLOAT16"), ++ # ("UINT4", "UINT8"), ++ # ("INT4", "FLOAT"), ++ # ("INT4", "FLOAT16"), ++ # ("INT4", "INT8"), \ No newline at end of file From 41af71b6b336bab074744f81b67507fec40df0f1 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Wed, 25 Jun 2025 17:06:34 -0700 Subject: [PATCH 47/88] update patch --- cmake/patches/onnx/onnx.patch | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index fa7f0f64919bd..43b4840845cbf 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -204,8 +204,8 @@ diff --git a/onnx/backend/test/case/node/cast.py b/onnx/backend/test/case/node/c index 9696373920d72e782948cf9b5cf137983d229814..226f57eee4701dc8fefbf5d64983385b077bc80e 100644 --- a/onnx/backend/test/case/node/cast.py +++ b/onnx/backend/test/case/node/cast.py -@@ -59,10 +59,11 @@ class Cast(Base): - test_cases = [ +@@ -59,10 +59,11 @@ + ("FLOAT8E5M2FNUZ", "FLOAT16"), - ("FLOAT", "UINT4"), - ("FLOAT16", "UINT4"), - ("FLOAT", "INT4"), @@ -226,4 +226,5 @@ index 9696373920d72e782948cf9b5cf137983d229814..226f57eee4701dc8fefbf5d64983385b + # ("UINT4", "UINT8"), + # ("INT4", "FLOAT"), + # ("INT4", "FLOAT16"), -+ # ("INT4", "INT8"), \ No newline at end of file ++ # ("INT4", "INT8"), + ("FLOAT4E2M1", "FLOAT"), \ No newline at end of file From 5b79476350bbcdc78fee2c4a3f8922b16a890d10 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Wed, 25 Jun 2025 17:22:20 -0700 Subject: [PATCH 48/88] update patch --- cmake/patches/onnx/onnx.patch | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index 43b4840845cbf..4b30c1941db33 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -204,7 +204,7 @@ diff --git a/onnx/backend/test/case/node/cast.py b/onnx/backend/test/case/node/c index 9696373920d72e782948cf9b5cf137983d229814..226f57eee4701dc8fefbf5d64983385b077bc80e 100644 --- a/onnx/backend/test/case/node/cast.py +++ b/onnx/backend/test/case/node/cast.py -@@ -59,10 +59,11 @@ +@@ -58,12 +58,13 @@ ("FLOAT8E5M2FNUZ", "FLOAT16"), - ("FLOAT", 
"UINT4"), - ("FLOAT16", "UINT4"), @@ -227,4 +227,4 @@ index 9696373920d72e782948cf9b5cf137983d229814..226f57eee4701dc8fefbf5d64983385b + # ("INT4", "FLOAT"), + # ("INT4", "FLOAT16"), + # ("INT4", "INT8"), - ("FLOAT4E2M1", "FLOAT"), \ No newline at end of file + ("FLOAT4E2M1", "FLOAT"), From b39c660a8dda37a3f83af1c3a4284323f2166816 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Wed, 25 Jun 2025 18:44:17 -0700 Subject: [PATCH 49/88] keep binskim.patch in sync with onnx.patch --- cmake/vcpkg-ports/onnx/binskim.patch | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/cmake/vcpkg-ports/onnx/binskim.patch b/cmake/vcpkg-ports/onnx/binskim.patch index 30d5a44a1d1cc..4b30c1941db33 100644 --- a/cmake/vcpkg-ports/onnx/binskim.patch +++ b/cmake/vcpkg-ports/onnx/binskim.patch @@ -200,3 +200,31 @@ index 0aab3e26..27f32195 100644 + + #endif // ! ONNX_ONNX_PB_H +diff --git a/onnx/backend/test/case/node/cast.py b/onnx/backend/test/case/node/cast.py +index 9696373920d72e782948cf9b5cf137983d229814..226f57eee4701dc8fefbf5d64983385b077bc80e 100644 +--- a/onnx/backend/test/case/node/cast.py ++++ b/onnx/backend/test/case/node/cast.py +@@ -58,12 +58,13 @@ + ("FLOAT8E5M2FNUZ", "FLOAT16"), +- ("FLOAT", "UINT4"), +- ("FLOAT16", "UINT4"), +- ("FLOAT", "INT4"), +- ("FLOAT16", "INT4"), +- ("UINT4", "FLOAT"), +- ("UINT4", "FLOAT16"), +- ("UINT4", "UINT8"), +- ("INT4", "FLOAT"), +- ("INT4", "FLOAT16"), +- ("INT4", "INT8"), ++ # Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074 ++ # ("FLOAT", "UINT4"), ++ # ("FLOAT16", "UINT4"), ++ # ("FLOAT", "INT4"), ++ # ("FLOAT16", "INT4"), ++ # ("UINT4", "FLOAT"), ++ # ("UINT4", "FLOAT16"), ++ # ("UINT4", "UINT8"), ++ # ("INT4", "FLOAT"), ++ # ("INT4", "FLOAT16"), ++ # ("INT4", "INT8"), + ("FLOAT4E2M1", "FLOAT"), From 10b2eea2e27cb57c87ab31d8aba2fe1dfe277eb4 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Wed, 25 Jun 2025 19:43:34 -0700 Subject: [PATCH 50/88] update patches --- cmake/patches/onnx/onnx.patch | 10 ++++++++++ cmake/vcpkg-ports/onnx/binskim.patch | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index 4b30c1941db33..ff5bc9b2e1d10 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -228,3 +228,13 @@ index 9696373920d72e782948cf9b5cf137983d229814..226f57eee4701dc8fefbf5d64983385b + # ("INT4", "FLOAT16"), + # ("INT4", "INT8"), ("FLOAT4E2M1", "FLOAT"), +@@ -232,7 +233,8 @@ + expected_tensor = make_tensor( + "x", getattr(TensorProto, to_type), [3, 5], expected.tolist() + ) + output = expected_tensor +- elif from_type in ("UINT4", "INT4") or to_type in ("UINT4", "INT4"): ++ # Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074 ++ elif False and (from_type in ("UINT4", "INT4") or to_type in ("UINT4", "INT4")): + np_fp32 = np.arange(-9, 16).astype(np.float32) + input_shape = (5, 5) diff --git a/cmake/vcpkg-ports/onnx/binskim.patch b/cmake/vcpkg-ports/onnx/binskim.patch index 4b30c1941db33..ff5bc9b2e1d10 100644 --- a/cmake/vcpkg-ports/onnx/binskim.patch +++ b/cmake/vcpkg-ports/onnx/binskim.patch @@ -228,3 +228,13 @@ index 9696373920d72e782948cf9b5cf137983d229814..226f57eee4701dc8fefbf5d64983385b + # ("INT4", "FLOAT16"), + # ("INT4", "INT8"), ("FLOAT4E2M1", "FLOAT"), +@@ -232,7 +233,8 @@ + expected_tensor = make_tensor( + "x", getattr(TensorProto, to_type), [3, 5], expected.tolist() + ) + output = expected_tensor +- elif from_type 
in ("UINT4", "INT4") or to_type in ("UINT4", "INT4"): ++ # Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074 ++ elif False and (from_type in ("UINT4", "INT4") or to_type in ("UINT4", "INT4")): + np_fp32 = np.arange(-9, 16).astype(np.float32) + input_shape = (5, 5) From c89e11f34bd065d5c4e29c62291e3da0948116b2 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Wed, 25 Jun 2025 20:22:48 -0700 Subject: [PATCH 51/88] exclude onnx tests in TestCase.cc --- onnxruntime/test/onnx/TestCase.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 18e82b529a147..c05aa8f8bf431 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -948,6 +948,16 @@ std::unique_ptr> GetBrokenTests(const std::string& provider {"slice_neg_steps", "Type parameter (Tind) bound to different types (tensor(int64) and tensor(int32) in node ()."}, {"cast_BFLOAT16_to_FLOAT", "Unexpected input data type"}, + {"cast_FLOAT16_to_INT4", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, + {"cast_FLOAT16_to_UINT4", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, + {"cast_FLOAT_to_INT4", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, + {"cast_FLOAT_to_UINT4", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, + {"cast_INT4_to_FLOAT", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, + {"cast_INT4_to_FLOAT16", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, + {"cast_INT4_to_INT8", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, + {"cast_UINT4_to_FLOAT", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, + {"cast_UINT4_to_FLOAT16", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, + {"cast_UINT4_to_UINT8", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, {"loop13_seq", "Creation of empty sequences is currently not supported in the test runner"}, {"sequence_insert_at_front", "shape mismatch, expect {4} got {3}"}, {"cast_FLOAT_to_BFLOAT16", "expect uint16 got bfloat16"}, From da7e4448efbb41effafec16b1846447a8c9562f4 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 30 Jun 2025 14:09:39 -0700 Subject: [PATCH 52/88] remove patch fixes --- cmake/patches/onnx/onnx.patch | 89 ++-------------------------- cmake/vcpkg-ports/onnx/binskim.patch | 89 ++-------------------------- 2 files changed, 8 insertions(+), 170 deletions(-) diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index ff5bc9b2e1d10..922deaf7d85b1 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index 8b5af303..7fe05a5a 100644 +index 8b5af303..8593fe4a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,7 @@ option(ONNX_USE_LITE_PROTO "Use lite protobuf instead of full." 
OFF) @@ -47,15 +47,7 @@ index 8b5af303..7fe05a5a 100644 add_library(onnx_proto ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS}) add_dependencies(onnx_proto gen_onnx_operators_proto gen_onnx_data_proto) -@@ -492,6 +507,7 @@ if(MSVC) - endif() - else() - # On non-Windows, hide all symbols we don't need -+ set(EXTRA_FLAGS "-Wno-unused-parameter") - set(ONNX_API_DEFINE "-DONNX_API=__attribute__\(\(__visibility__\(\"default\"\)\)\)") - set_target_properties(onnx_proto PROPERTIES CXX_VISIBILITY_PRESET hidden) - set_target_properties(onnx_proto PROPERTIES VISIBILITY_INLINES_HIDDEN 1) -@@ -595,13 +611,6 @@ if(ONNX_BUILD_PYTHON) +@@ -595,13 +610,6 @@ if(ONNX_BUILD_PYTHON) target_link_libraries(onnx_cpp2py_export PRIVATE ${Python3_LIBRARIES}) target_compile_options(onnx_cpp2py_export PRIVATE /MP @@ -69,7 +61,7 @@ index 8b5af303..7fe05a5a 100644 ${EXTRA_FLAGS}) add_msvc_runtime_flag(onnx_cpp2py_export) add_onnx_global_defines(onnx_cpp2py_export) -@@ -618,23 +627,9 @@ endif() +@@ -618,23 +626,9 @@ endif() if(MSVC) target_compile_options(onnx_proto PRIVATE /MP @@ -164,77 +156,4 @@ index acf3aac7..5bef6e72 100644 + OpSchemaRegisterOnce(OpSchema op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { OpSchemaRegisterNoExcept(std::move(op_schema), opset_version_to_load, fail_duplicate_schema); } - static void -diff --git a/onnx/onnx_pb.h b/onnx/onnx_pb.h -index 0aab3e26..27f32195 100644 ---- a/onnx/onnx_pb.h -+++ b/onnx/onnx_pb.h -@@ -47,10 +47,30 @@ - #define ONNX_API ONNX_IMPORT - #endif - -+#if defined(__GNUC__) -+#pragma GCC diagnostic push -+ -+// In file included from onnx/onnx-ml.pb.h:30: -+// In file included from google/protobuf/extension_set.h:53: -+// google/protobuf/parse_context.h:328:47: error: implicit conversion loses integer precision: 'long' to 'int' [-Werror,-Wshorten-64-to-32] -+#if defined(__has_warning) -+#if __has_warning("-Wshorten-64-to-32") -+#pragma GCC diagnostic ignored "-Wshorten-64-to-32" -+#endif -+#endif // defined(__has_warning) -+ -+#endif // defined(__GNUC__) -+ -+ - #ifdef ONNX_ML - #include "onnx/onnx-ml.pb.h" - #else - #include "onnx/onnx.pb.h" - #endif - -+#if defined(__GNUC__) -+#pragma GCC diagnostic pop -+#endif -+ -+ - #endif // ! 
ONNX_ONNX_PB_H -diff --git a/onnx/backend/test/case/node/cast.py b/onnx/backend/test/case/node/cast.py -index 9696373920d72e782948cf9b5cf137983d229814..226f57eee4701dc8fefbf5d64983385b077bc80e 100644 ---- a/onnx/backend/test/case/node/cast.py -+++ b/onnx/backend/test/case/node/cast.py -@@ -58,12 +58,13 @@ - ("FLOAT8E5M2FNUZ", "FLOAT16"), -- ("FLOAT", "UINT4"), -- ("FLOAT16", "UINT4"), -- ("FLOAT", "INT4"), -- ("FLOAT16", "INT4"), -- ("UINT4", "FLOAT"), -- ("UINT4", "FLOAT16"), -- ("UINT4", "UINT8"), -- ("INT4", "FLOAT"), -- ("INT4", "FLOAT16"), -- ("INT4", "INT8"), -+ # Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074 -+ # ("FLOAT", "UINT4"), -+ # ("FLOAT16", "UINT4"), -+ # ("FLOAT", "INT4"), -+ # ("FLOAT16", "INT4"), -+ # ("UINT4", "FLOAT"), -+ # ("UINT4", "FLOAT16"), -+ # ("UINT4", "UINT8"), -+ # ("INT4", "FLOAT"), -+ # ("INT4", "FLOAT16"), -+ # ("INT4", "INT8"), - ("FLOAT4E2M1", "FLOAT"), -@@ -232,7 +233,8 @@ - expected_tensor = make_tensor( - "x", getattr(TensorProto, to_type), [3, 5], expected.tolist() - ) - output = expected_tensor -- elif from_type in ("UINT4", "INT4") or to_type in ("UINT4", "INT4"): -+ # Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074 -+ elif False and (from_type in ("UINT4", "INT4") or to_type in ("UINT4", "INT4")): - np_fp32 = np.arange(-9, 16).astype(np.float32) - input_shape = (5, 5) + static void \ No newline at end of file diff --git a/cmake/vcpkg-ports/onnx/binskim.patch b/cmake/vcpkg-ports/onnx/binskim.patch index ff5bc9b2e1d10..922deaf7d85b1 100644 --- a/cmake/vcpkg-ports/onnx/binskim.patch +++ b/cmake/vcpkg-ports/onnx/binskim.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index 8b5af303..7fe05a5a 100644 +index 8b5af303..8593fe4a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,7 @@ option(ONNX_USE_LITE_PROTO "Use lite protobuf instead of full." 
OFF) @@ -47,15 +47,7 @@ index 8b5af303..7fe05a5a 100644 add_library(onnx_proto ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS}) add_dependencies(onnx_proto gen_onnx_operators_proto gen_onnx_data_proto) -@@ -492,6 +507,7 @@ if(MSVC) - endif() - else() - # On non-Windows, hide all symbols we don't need -+ set(EXTRA_FLAGS "-Wno-unused-parameter") - set(ONNX_API_DEFINE "-DONNX_API=__attribute__\(\(__visibility__\(\"default\"\)\)\)") - set_target_properties(onnx_proto PROPERTIES CXX_VISIBILITY_PRESET hidden) - set_target_properties(onnx_proto PROPERTIES VISIBILITY_INLINES_HIDDEN 1) -@@ -595,13 +611,6 @@ if(ONNX_BUILD_PYTHON) +@@ -595,13 +610,6 @@ if(ONNX_BUILD_PYTHON) target_link_libraries(onnx_cpp2py_export PRIVATE ${Python3_LIBRARIES}) target_compile_options(onnx_cpp2py_export PRIVATE /MP @@ -69,7 +61,7 @@ index 8b5af303..7fe05a5a 100644 ${EXTRA_FLAGS}) add_msvc_runtime_flag(onnx_cpp2py_export) add_onnx_global_defines(onnx_cpp2py_export) -@@ -618,23 +627,9 @@ endif() +@@ -618,23 +626,9 @@ endif() if(MSVC) target_compile_options(onnx_proto PRIVATE /MP @@ -164,77 +156,4 @@ index acf3aac7..5bef6e72 100644 + OpSchemaRegisterOnce(OpSchema op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { OpSchemaRegisterNoExcept(std::move(op_schema), opset_version_to_load, fail_duplicate_schema); } - static void -diff --git a/onnx/onnx_pb.h b/onnx/onnx_pb.h -index 0aab3e26..27f32195 100644 ---- a/onnx/onnx_pb.h -+++ b/onnx/onnx_pb.h -@@ -47,10 +47,30 @@ - #define ONNX_API ONNX_IMPORT - #endif - -+#if defined(__GNUC__) -+#pragma GCC diagnostic push -+ -+// In file included from onnx/onnx-ml.pb.h:30: -+// In file included from google/protobuf/extension_set.h:53: -+// google/protobuf/parse_context.h:328:47: error: implicit conversion loses integer precision: 'long' to 'int' [-Werror,-Wshorten-64-to-32] -+#if defined(__has_warning) -+#if __has_warning("-Wshorten-64-to-32") -+#pragma GCC diagnostic ignored "-Wshorten-64-to-32" -+#endif -+#endif // defined(__has_warning) -+ -+#endif // defined(__GNUC__) -+ -+ - #ifdef ONNX_ML - #include "onnx/onnx-ml.pb.h" - #else - #include "onnx/onnx.pb.h" - #endif - -+#if defined(__GNUC__) -+#pragma GCC diagnostic pop -+#endif -+ -+ - #endif // ! 
ONNX_ONNX_PB_H -diff --git a/onnx/backend/test/case/node/cast.py b/onnx/backend/test/case/node/cast.py -index 9696373920d72e782948cf9b5cf137983d229814..226f57eee4701dc8fefbf5d64983385b077bc80e 100644 ---- a/onnx/backend/test/case/node/cast.py -+++ b/onnx/backend/test/case/node/cast.py -@@ -58,12 +58,13 @@ - ("FLOAT8E5M2FNUZ", "FLOAT16"), -- ("FLOAT", "UINT4"), -- ("FLOAT16", "UINT4"), -- ("FLOAT", "INT4"), -- ("FLOAT16", "INT4"), -- ("UINT4", "FLOAT"), -- ("UINT4", "FLOAT16"), -- ("UINT4", "UINT8"), -- ("INT4", "FLOAT"), -- ("INT4", "FLOAT16"), -- ("INT4", "INT8"), -+ # Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074 -+ # ("FLOAT", "UINT4"), -+ # ("FLOAT16", "UINT4"), -+ # ("FLOAT", "INT4"), -+ # ("FLOAT16", "INT4"), -+ # ("UINT4", "FLOAT"), -+ # ("UINT4", "FLOAT16"), -+ # ("UINT4", "UINT8"), -+ # ("INT4", "FLOAT"), -+ # ("INT4", "FLOAT16"), -+ # ("INT4", "INT8"), - ("FLOAT4E2M1", "FLOAT"), -@@ -232,7 +233,8 @@ - expected_tensor = make_tensor( - "x", getattr(TensorProto, to_type), [3, 5], expected.tolist() - ) - output = expected_tensor -- elif from_type in ("UINT4", "INT4") or to_type in ("UINT4", "INT4"): -+ # Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074 -+ elif False and (from_type in ("UINT4", "INT4") or to_type in ("UINT4", "INT4")): - np_fp32 = np.arange(-9, 16).astype(np.float32) - input_shape = (5, 5) + static void \ No newline at end of file From be31bdd2162f8372ad1fca6f992bb967b5d3aa44 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 30 Jun 2025 14:17:53 -0700 Subject: [PATCH 53/88] Add newline at end of patch files --- cmake/patches/onnx/onnx.patch | 2 +- cmake/vcpkg-ports/onnx/binskim.patch | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index 922deaf7d85b1..f51370212ff5a 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -156,4 +156,4 @@ index acf3aac7..5bef6e72 100644 + OpSchemaRegisterOnce(OpSchema op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { OpSchemaRegisterNoExcept(std::move(op_schema), opset_version_to_load, fail_duplicate_schema); } - static void \ No newline at end of file + static void diff --git a/cmake/vcpkg-ports/onnx/binskim.patch b/cmake/vcpkg-ports/onnx/binskim.patch index 922deaf7d85b1..f51370212ff5a 100644 --- a/cmake/vcpkg-ports/onnx/binskim.patch +++ b/cmake/vcpkg-ports/onnx/binskim.patch @@ -156,4 +156,4 @@ index acf3aac7..5bef6e72 100644 + OpSchemaRegisterOnce(OpSchema op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { OpSchemaRegisterNoExcept(std::move(op_schema), opset_version_to_load, fail_duplicate_schema); } - static void \ No newline at end of file + static void From ea2956c3c03245136587f556c20267162308d24d Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Tue, 1 Jul 2025 16:31:55 -0700 Subject: [PATCH 54/88] explicitly mention next onnx version in skipped tests --- onnxruntime/test/onnx/TestCase.cc | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index d1d9c2b6abe8f..647c947f37f0c 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -948,16 +948,16 @@ std::unique_ptr> GetBrokenTests(const std::string& provider {"slice_neg_steps", "Type parameter (Tind) bound to different types (tensor(int64) and 
tensor(int32) in node ()."}, {"cast_BFLOAT16_to_FLOAT", "Unexpected input data type"}, - {"cast_FLOAT16_to_INT4", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, - {"cast_FLOAT16_to_UINT4", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, - {"cast_FLOAT_to_INT4", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, - {"cast_FLOAT_to_UINT4", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, - {"cast_INT4_to_FLOAT", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, - {"cast_INT4_to_FLOAT16", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, - {"cast_INT4_to_INT8", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, - {"cast_UINT4_to_FLOAT", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, - {"cast_UINT4_to_FLOAT16", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, - {"cast_UINT4_to_UINT8", "Skipped until onnxruntime/cmake/external/onnx points to a version of onnx which includes @onnx/onnx/pull/7074"}, + {"cast_FLOAT16_to_INT4", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_FLOAT16_to_UINT4", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_FLOAT_to_INT4", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_FLOAT_to_UINT4", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_INT4_to_FLOAT", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_INT4_to_FLOAT16", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_INT4_to_INT8", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_UINT4_to_FLOAT", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_UINT4_to_FLOAT16", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, + {"cast_UINT4_to_UINT8", "Skipped until onnxruntime/cmake/external/onnx points to onnx 1.19 which should include @onnx/onnx/pull/7074"}, {"loop13_seq", "Creation of empty sequences is currently not supported in the test runner"}, {"sequence_insert_at_front", "shape mismatch, expect {4} got {3}"}, {"cast_FLOAT_to_BFLOAT16", "expect uint16 got bfloat16"}, From 989719a1b45e148f2aeff1bbaccdb5a9610f4591 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Tue, 1 Jul 2025 16:37:12 -0700 Subject: [PATCH 55/88] use std::is_floating_point_v --- onnxruntime/core/providers/cpu/tensor/cast_op.cc | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 2dedaa4fb33d8..96ca4ca814f72 100644 --- 
a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -73,13 +73,6 @@ struct IsStandardIntegerType { std::is_same_v; }; -template -struct IsStandardFloatType { - static constexpr bool value = - std::is_same_v || - std::is_same_v; -}; - // string cast helpers // Note: when C++17 is available, use functions @@ -457,7 +450,7 @@ struct TensorCaster struct TensorCaster::value>> { + std::enable_if_t>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -526,7 +519,7 @@ struct TensorCaster -struct TensorCaster::value>> { +struct TensorCaster>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); @@ -576,7 +569,7 @@ struct TensorCaster { template struct TensorCaster::value || IsStandardFloatType::value || IsOrtFloat16Type::value + std::enable_if_t::value || std::is_floating_point_v || IsOrtFloat16Type::value #if !defined(DISABLE_FLOAT8_TYPES) || IsOrtFloat8Type::value #endif @@ -627,7 +620,7 @@ struct TensorCaster { template struct TensorCaster::value || IsStandardFloatType::value || IsOrtFloat16Type::value + std::enable_if_t::value || std::is_floating_point_v || IsOrtFloat16Type::value #if !defined(DISABLE_FLOAT8_TYPES) || IsOrtFloat8Type::value #endif From e2be244262de3b3788d7363a2a3f4777029a9cc5 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Tue, 1 Jul 2025 18:48:38 -0700 Subject: [PATCH 56/88] use constants for min anx max (u)int4 values --- .../core/providers/cpu/tensor/cast_op.cc | 48 +++++++++++-------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 96ca4ca814f72..213c20577ab69 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -191,8 +191,14 @@ struct EigenCastType { using type = Eigen::bfloat16; }; -// Helper struct for converting from Int4x2/UInt4x2 elements to any destination type namespace { + +constexpr int INT4_MIN = -8; +constexpr int INT4_MAX = 7; +constexpr unsigned int UINT4_MIN = 0; +constexpr unsigned int UINT4_MAX = 15; + +// Helper struct for converting from Int4x2/UInt4x2 elements to any destination type template struct Int4ElementConverter { static DstType Convert(int8_t val) { @@ -215,41 +221,41 @@ template struct ToInt4ElementConverter { // Default implementation for most numeric types static int8_t ConvertToInt4(const SrcType& val) { - int8_t result = static_cast(val); + int result = static_cast(val); // Clamp to int4 range (-8 to 7) - return std::clamp(result, int8_t(-8), int8_t(7)); + return static_cast(std::clamp(result, INT4_MIN, INT4_MAX)); } static uint8_t ConvertToUInt4(const SrcType& val) { - uint8_t result = static_cast(val); + unsigned int result = static_cast(val); // Clamp to uint4 range (0 to 15) - return std::min(result, uint8_t(15)); + return static_cast(std::min(result, UINT4_MAX)); } }; template <> struct ToInt4ElementConverter { static int8_t ConvertToInt4(const float& val) { - int8_t result = static_cast(std::roundf(val)); - return std::clamp(result, static_cast(-8), static_cast(7)); + int result = static_cast(std::roundf(val)); + return static_cast(std::clamp(result, INT4_MIN, INT4_MAX)); } static uint8_t ConvertToUInt4(const float& val) { - uint8_t result = 
static_cast(std::max(0.0f, std::roundf(val))); - return std::min(result, static_cast(15)); + unsigned int result = static_cast(std::max(0.0f, std::roundf(val))); + return static_cast(std::min(result, UINT4_MAX)); } }; template <> struct ToInt4ElementConverter { static int8_t ConvertToInt4(const double& val) { - int8_t result = static_cast(std::round(val)); - return std::clamp(result, static_cast(-8), static_cast(7)); + int result = static_cast(std::round(val)); + return static_cast(std::clamp(result, INT4_MIN, INT4_MAX)); } static uint8_t ConvertToUInt4(const double& val) { - uint8_t result = static_cast(std::max(0.0, std::round(val))); - return std::min(result, static_cast(15)); + unsigned int result = static_cast(std::max(0.0, std::round(val))); + return static_cast(std::min(result, UINT4_MAX)); } }; @@ -360,8 +366,8 @@ struct TensorCaster { // Parse each string and clamp to int4 range (-8 to 7) int v0 = std::stoi(in_data[i]); int v1 = std::stoi(in_data[i + 1]); - int8_t val0 = static_cast(std::clamp(v0, -8, 7)); - int8_t val1 = static_cast(std::clamp(v1, -8, 7)); + int8_t val0 = static_cast(std::clamp(v0, INT4_MIN, INT4_MAX)); + int8_t val1 = static_cast(std::clamp(v1, INT4_MIN, INT4_MAX)); out_data[i >> 1] = Int4x2(val0, val1); } @@ -369,7 +375,7 @@ struct TensorCaster { // Handle odd number of elements - pad with 0 if (i < shape_size) { int v0 = std::stoi(in_data[i]); - int8_t val0 = static_cast(std::clamp(v0, -8, 7)); + int8_t val0 = static_cast(std::clamp(v0, INT4_MIN, INT4_MAX)); out_data[i >> 1] = Int4x2(val0, 0); } @@ -390,8 +396,8 @@ struct TensorCaster { // Parse each string and clamp to uint4 range (0 to 15) int v0 = std::stoi(in_data[i]); int v1 = std::stoi(in_data[i + 1]); - uint8_t val0 = static_cast(std::clamp(v0, 0, 15)); - uint8_t val1 = static_cast(std::clamp(v1, 0, 15)); + uint8_t val0 = static_cast(std::clamp(v0, static_cast(UINT4_MIN), static_cast(UINT4_MAX))); + uint8_t val1 = static_cast(std::clamp(v1, static_cast(UINT4_MIN), static_cast(UINT4_MAX))); out_data[i >> 1] = UInt4x2(val0, val1); } @@ -399,7 +405,7 @@ struct TensorCaster { // Handle odd number of elements - pad with 0 if (i < shape_size) { int v0 = std::stoi(in_data[i]); - uint8_t val0 = static_cast(std::clamp(v0, 0, 15)); + uint8_t val0 = static_cast(std::clamp(v0, static_cast(UINT4_MIN), static_cast(UINT4_MAX))); out_data[i >> 1] = UInt4x2(val0, 0); } @@ -559,8 +565,8 @@ struct TensorCaster { auto high_nibble = in_data[i].GetElem(1); // Convert to signed by clamping to int4 range (-8 to 7) - int8_t low_signed = std::clamp(static_cast(low_nibble), int8_t(-8), int8_t(7)); - int8_t high_signed = std::clamp(static_cast(high_nibble), int8_t(-8), int8_t(7)); + int8_t low_signed = static_cast(std::clamp(static_cast(low_nibble), INT4_MIN, INT4_MAX)); + int8_t high_signed = static_cast(std::clamp(static_cast(high_nibble), INT4_MIN, INT4_MAX)); out_data[i] = Int4x2(low_signed, high_signed); } From c7c05faa0d36344f3a24c39f6ffc3edea6f94a71 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Tue, 1 Jul 2025 19:28:37 -0700 Subject: [PATCH 57/88] use constexpr if to merge specializations --- .../core/providers/cpu/tensor/cast_op.cc | 73 +++++++++---------- 1 file changed, 36 insertions(+), 37 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 213c20577ab69..226ec10244017 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -322,12 +322,13 @@ struct TensorCaster { template <> 
struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const std::ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); auto* out_data = out.MutableData(); // Unpack each Int4x2 into two separate string elements size_t out_idx = 0; - for (size_t i = 0; i < narrow(shape.Size()); ++i) { + for (ptrdiff_t i = 0; i < shape_size; ++i) { // elem 0 is the low nibble, 1 the high nibble auto val = in_data[i >> 1].GetElem(i & 0x1); @@ -339,12 +340,13 @@ struct TensorCaster { template <> struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const std::ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); auto* out_data = out.MutableData(); // Unpack each UInt4x2 into two separate string elements size_t out_idx = 0; - for (size_t i = 0; i < narrow(shape.Size()); ++i) { + for (ptrdiff_t i = 0; i < shape_size; ++i) { // elem 0 is the low nibble, 1 the high nibble auto val = in_data[i >> 1].GetElem(i & 0x1); @@ -356,12 +358,12 @@ struct TensorCaster { template <> struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); auto* out_data = out.MutableData(); // Every 2 strings combine into 1 Int4x2 - const size_t shape_size = narrow(shape.Size()); - size_t i = 0; + ptrdiff_t i = 0; for (; i < shape_size - 1; i += 2) { // Parse each string and clamp to int4 range (-8 to 7) int v0 = std::stoi(in_data[i]); @@ -386,12 +388,12 @@ struct TensorCaster { template <> struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); auto* out_data = out.MutableData(); // Every 2 strings combine into 1 UInt4x2 - const size_t shape_size = narrow(shape.Size()); - size_t i = 0; + ptrdiff_t i = 0; for (; i < shape_size - 1; i += 2) { // Parse each string and clamp to uint4 range (0 to 15) int v0 = std::stoi(in_data[i]); @@ -442,10 +444,11 @@ struct TensorCaster> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()); ++i) { + for (ptrdiff_t i = 0; i < shape_size; ++i) { // elem 0 is the low nibble, 1 the high nibble auto val = in_data[i >> 1].GetElem(i & 0x1); @@ -458,10 +461,11 @@ template struct TensorCaster>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()); ++i) { + for (ptrdiff_t i = 0; i < shape_size; ++i) { // elem 0 is the low nibble, 1 the high nibble auto val = in_data[i >> 1].GetElem(i & 0x1); @@ -473,10 +477,11 @@ struct TensorCaster struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()); ++i) { + for (ptrdiff_t i = 0; i < shape_size; ++i) { // elem 0 is the low nibble, 1 the high nibble 
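// For example, flat element i = 4 reads in_data[2].GetElem(0) (the low nibble) and
// i = 5 reads in_data[2].GetElem(1) (the high nibble): i >> 1 selects the packed
// pair and i & 0x1 selects the nibble within it.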
auto val = in_data[i >> 1].GetElem(i & 0x1); @@ -488,10 +493,11 @@ struct TensorCaster { template <> struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size() + 1) >> 1; const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size() + 1) >> 1; ++i) { + for (ptrdiff_t i = 0; i < shape_size; ++i) { auto low_nibble = in_data[i].GetElem(0); auto high_nibble = in_data[i].GetElem(1); @@ -506,35 +512,26 @@ struct TensorCaster { template struct TensorCaster::value || IsOrtFloat16Type::value + std::enable_if_t::value || IsOrtFloat16Type::value || std::is_floating_point_v #if !defined(DISABLE_FLOAT8_TYPES) || IsOrtFloat8Type::value #endif >> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()); ++i) { + for (ptrdiff_t i = 0; i < shape_size; ++i) { // elem 0 is the low nibble, 1 the high nibble auto val = in_data[i >> 1].GetElem(i & 0x1); out_data[i] = Int4ElementConverter::Convert(val); - } - } -}; - -template -struct TensorCaster>> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - for (size_t i = 0; i < narrow(shape.Size()); ++i) { - // elem 0 is the low nibble, 1 the high nibble - auto val = in_data[i >> 1].GetElem(i & 0x1); - - out_data[i] = static_cast(val); + if constexpr (std::is_floating_point_v) { + out_data[i] = static_cast(val); + } else { + out_data[i] = Int4ElementConverter::Convert(val); + } } } }; @@ -542,10 +539,11 @@ struct TensorCaster struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size()); ++i) { + for (ptrdiff_t i = 0; i < shape_size; ++i) { // elem 0 is the low nibble, 1 the high nibble auto val = in_data[i >> 1].GetElem(i & 0x1); @@ -557,10 +555,11 @@ struct TensorCaster { template <> struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size() + 1) >> 1; const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - for (size_t i = 0; i < narrow(shape.Size() + 1) >> 1; ++i) { + for (ptrdiff_t i = 0; i < shape_size; ++i) { auto low_nibble = in_data[i].GetElem(0); auto high_nibble = in_data[i].GetElem(1); @@ -581,11 +580,11 @@ struct TensorCaster> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - const size_t shape_size = narrow(shape.Size()); - size_t i = 0; + ptrdiff_t i = 0; for (; i < shape_size - 1; i += 2) { int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); @@ -604,11 +603,11 @@ struct TensorCaster struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = 
narrow(shape.Size()); const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - const size_t shape_size = narrow(shape.Size()); - size_t i = 0; + ptrdiff_t i = 0; for (; i < shape_size - 1; i += 2) { int8_t low_val = in_data[i] ? 1 : 0; int8_t high_val = in_data[i + 1] ? 1 : 0; @@ -632,11 +631,11 @@ struct TensorCaster> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - const size_t shape_size = narrow(shape.Size()); - size_t i = 0; + ptrdiff_t i = 0; for (; i < shape_size - 1; i += 2) { uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); @@ -655,11 +654,11 @@ struct TensorCaster struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - const size_t shape_size = narrow(shape.Size()); - size_t i = 0; + ptrdiff_t i = 0; for (; i < shape_size - 1; i += 2) { uint8_t low_val = in_data[i] ? 1 : 0; uint8_t high_val = in_data[i + 1] ? 1 : 0; From 393219987a67703887f84099bfd2a648ff1360b5 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Tue, 1 Jul 2025 19:35:29 -0700 Subject: [PATCH 58/88] use constexpr if to merge specializations for Int4 --- .../core/providers/cpu/tensor/cast_op.cc | 25 +++++-------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 226ec10244017..675d4dfad9b35 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -438,7 +438,7 @@ struct TensorCaster { template struct TensorCaster::value || IsOrtFloat16Type::value + std::enable_if_t::value || IsOrtFloat16Type::value || std::is_floating_point_v #if !defined(DISABLE_FLOAT8_TYPES) || IsOrtFloat8Type::value #endif @@ -452,24 +452,11 @@ struct TensorCaster> 1].GetElem(i & 0x1); - out_data[i] = Int4ElementConverter::Convert(val); - } - } -}; - -template -struct TensorCaster>> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - for (ptrdiff_t i = 0; i < shape_size; ++i) { - // elem 0 is the low nibble, 1 the high nibble - auto val = in_data[i >> 1].GetElem(i & 0x1); - - out_data[i] = static_cast(val); + if constexpr (std::is_floating_point_v) { + out_data[i] = static_cast(val); + } else { + out_data[i] = Int4ElementConverter::Convert(val); + } } } }; From 5e00b26f23cb9e067d8e3d63e1d39af1d4df53db Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Tue, 1 Jul 2025 19:41:46 -0700 Subject: [PATCH 59/88] remove extra line --- onnxruntime/core/providers/cpu/tensor/cast_op.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 675d4dfad9b35..447fb2ef50820 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -513,7 +513,6 @@ struct TensorCaster> 1].GetElem(i & 0x1); - out_data[i] = Int4ElementConverter::Convert(val); if constexpr (std::is_floating_point_v) { 
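// float and double destinations can take the signed nibble value directly;
// the remaining destinations admitted by the enable_if above (standard integers,
// MLFloat16/BFloat16 and, when enabled, the Float8 types) go through
// Int4ElementConverter in the else branch.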
out_data[i] = static_cast(val); } else { From af00c7df610440a8cb41c7ca1d4e7b7c31e0a253 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 7 Jul 2025 06:57:42 -0700 Subject: [PATCH 60/88] remove anonymous namespace --- onnxruntime/core/providers/cpu/tensor/cast_op.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 447fb2ef50820..b6e4032ca3ba5 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -191,8 +191,6 @@ struct EigenCastType { using type = Eigen::bfloat16; }; -namespace { - constexpr int INT4_MIN = -8; constexpr int INT4_MAX = 7; constexpr unsigned int UINT4_MIN = 0; @@ -275,7 +273,6 @@ struct ToInt4ElementConverter Y template From fe03e23116213c75db1a8d128066c479a85e8ead Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 7 Jul 2025 07:29:06 -0700 Subject: [PATCH 61/88] update IsOrtFloat8Type usage --- .../core/providers/cpu/tensor/cast_op.cc | 50 ++++--------------- 1 file changed, 10 insertions(+), 40 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index b6e4032ca3ba5..4bb755e47ac78 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -58,6 +58,9 @@ using IsOrtFloat16Type = boost::mp11::mp_contains, #if !defined(DISABLE_FLOAT8_TYPES) template using IsOrtFloat8Type = boost::mp11::mp_contains; +#else +template +struct IsOrtFloat8Type : std::false_type {}; #endif template @@ -128,11 +131,7 @@ CastToString(const SrcType& input, std::string& output) { } template -#if !defined(DISABLE_FLOAT8_TYPES) typename std::enable_if::value || IsOrtFloat8Type::value, void>::type -#else -typename std::enable_if::value>::type -#endif CastToString(const SrcType& input, std::string& output) { CastToString(static_cast(input), output); } @@ -162,11 +161,7 @@ CastFromString(const std::string& input, DstType& output) { } template -#if !defined(DISABLE_FLOAT8_TYPES) typename std::enable_if::value || IsOrtFloat8Type::value, void>::type -#else -typename std::enable_if::value, void>::type -#endif CastFromString(const std::string& input, DstType& output) { float intermediate; CastFromString(input, intermediate); @@ -202,13 +197,9 @@ struct Int4ElementConverter { static DstType Convert(int8_t val) { if constexpr (IsOrtFloat16Type::value) { return DstType(static_cast(val)); - } -#if !defined(DISABLE_FLOAT8_TYPES) - else if constexpr (IsOrtFloat8Type::value) { + } else if constexpr (IsOrtFloat8Type::value) { return DstType(static_cast(val), true); - } -#endif - else { + } else { return static_cast(val); } } @@ -259,11 +250,7 @@ struct ToInt4ElementConverter { template struct ToInt4ElementConverter::value -#if !defined(DISABLE_FLOAT8_TYPES) - || IsOrtFloat8Type::value -#endif - >> { + std::enable_if_t::value || IsOrtFloat8Type::value>> { static int8_t ConvertToInt4(const SrcType& val) { return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); } @@ -273,7 +260,6 @@ struct ToInt4ElementConverter Y template struct TensorCaster { @@ -435,11 +421,7 @@ struct TensorCaster { template struct TensorCaster::value || IsOrtFloat16Type::value || std::is_floating_point_v -#if !defined(DISABLE_FLOAT8_TYPES) - || IsOrtFloat8Type::value -#endif - >> { + std::enable_if_t::value || IsOrtFloat16Type::value || std::is_floating_point_v || IsOrtFloat8Type::value>> { void Cast(const OpKernelContext&, const 
TensorShape& shape, const Tensor& in, Tensor& out) const { const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); @@ -496,11 +478,7 @@ struct TensorCaster { template struct TensorCaster::value || IsOrtFloat16Type::value || std::is_floating_point_v -#if !defined(DISABLE_FLOAT8_TYPES) - || IsOrtFloat8Type::value -#endif - >> { + std::enable_if_t::value || IsOrtFloat16Type::value || std::is_floating_point_v || IsOrtFloat8Type::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); @@ -557,11 +535,7 @@ struct TensorCaster { template struct TensorCaster::value || std::is_floating_point_v || IsOrtFloat16Type::value -#if !defined(DISABLE_FLOAT8_TYPES) - || IsOrtFloat8Type::value -#endif - >> { + std::enable_if_t::value || std::is_floating_point_v || IsOrtFloat16Type::value || IsOrtFloat8Type::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); @@ -608,11 +582,7 @@ struct TensorCaster { template struct TensorCaster::value || std::is_floating_point_v || IsOrtFloat16Type::value -#if !defined(DISABLE_FLOAT8_TYPES) - || IsOrtFloat8Type::value -#endif - >> { + std::enable_if_t::value || std::is_floating_point_v || IsOrtFloat16Type::value || IsOrtFloat8Type::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); From d382f2c95c20e004f1f0c1bfa2b7908fba456775 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 7 Jul 2025 14:18:17 -0700 Subject: [PATCH 62/88] update cast between int4x2 and uint4x2 and tests, add Int4x2ToUInt64 test --- .../core/providers/cpu/tensor/cast_op.cc | 10 ++-- .../test/providers/cpu/tensor/cast_op_test.cc | 49 +++++++++++++------ 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 4bb755e47ac78..d01f882f82236 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -467,9 +467,8 @@ struct TensorCaster { auto low_nibble = in_data[i].GetElem(0); auto high_nibble = in_data[i].GetElem(1); - // Convert to unsigned by clamping at 0 - uint8_t low_unsigned = static_cast(std::max(0, static_cast(low_nibble)) & 0x0F); - uint8_t high_unsigned = static_cast(std::max(0, static_cast(high_nibble)) & 0x0F); + uint8_t low_unsigned = static_cast(low_nibble) & 0x0F; + uint8_t high_unsigned = static_cast(high_nibble) & 0x0F; out_data[i] = UInt4x2(low_unsigned, high_unsigned); } @@ -524,9 +523,8 @@ struct TensorCaster { auto low_nibble = in_data[i].GetElem(0); auto high_nibble = in_data[i].GetElem(1); - // Convert to signed by clamping to int4 range (-8 to 7) - int8_t low_signed = static_cast(std::clamp(static_cast(low_nibble), INT4_MIN, INT4_MAX)); - int8_t high_signed = static_cast(std::clamp(static_cast(high_nibble), INT4_MIN, INT4_MAX)); + int8_t low_signed = static_cast((low_nibble & 0x0F) << 4) >> 4; + int8_t high_signed = static_cast((high_nibble & 0x0F) << 4) >> 4; out_data[i] = Int4x2(low_signed, high_signed); } diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index f74d3e0cbfc51..5720446e850c0 100644 --- 
a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -337,6 +337,23 @@ TEST(CastOpTest, Int4x2ToInt64) { TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_int64_output), shape); } +TEST(CastOpTest, Int4x2ToUInt64) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), // boundary values + Int4x2(0, -1), // zero and negative + Int4x2(3, -5), // positive and negative + Int4x2(6, 2) // both positive + }; + + // Negative values will be cast to their unsigned representation + const std::vector expected_uint32_output = {18446744073709551608, 7, 0, 18446744073709551615, 3, 18446744073709551611, 6, 2}; + + // WHEN, THEN + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint32_output), shape); +} + TEST(CastOpTest, UInt4x2ToUInt8) { // GIVEN const std::vector shape{2, 2, 2}; @@ -681,16 +698,16 @@ TEST(CastOpTest, Int4x2ToUInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; const std::vector int4x2_input = { - Int4x2(-8, 7), // negative values get clamped to 0 - Int4x2(0, -1), // -1 becomes 0 - Int4x2(3, -5), // -5 becomes 0 - Int4x2(6, 2) // positive values remain + Int4x2(-8, 7), // negative values + Int4x2(0, -1), // -1 becomes max unsigned value + Int4x2(3, -5), // positive and negative values + Int4x2(6, 2) // positive values }; const std::vector expected_uint4x2_output = { - UInt4x2(0, 7), // -8 clamped to 0 - UInt4x2(0, 0), // -1 clamped to 0 - UInt4x2(3, 0), // -5 clamped to 0 + UInt4x2(8, 7), // -8 becomes 8 + UInt4x2(0, 15), // -1 becomes 15 + UInt4x2(3, 11), // -5 becomes 11 UInt4x2(6, 2) // unchanged }; @@ -702,17 +719,17 @@ TEST(CastOpTest, UInt4x2ToInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; const std::vector uint4x2_input = { - UInt4x2(0, 15), // 15 is out of int4 range, should be clamped to 7 - UInt4x2(1, 14), // 14 is out of int4 range, should be clamped to 7 - UInt4x2(7, 8), // 8 is out of int4 range, should be clamped to 7 + UInt4x2(0, 15), // 15 is out of int4 range + UInt4x2(1, 14), // 14 is out of int4 range + UInt4x2(7, 8), // 8 is out of int4 range UInt4x2(3, 6) // both within range }; const std::vector expected_int4x2_output = { - Int4x2(0, 7), // 15 clamped to 7 - Int4x2(1, 7), // 14 clamped to 7 - Int4x2(7, 7), // 8 clamped to 7 - Int4x2(3, 6) // unchanged + Int4x2(0, -1), // 15 becomes -1 + Int4x2(1, -2), // 14 becomes -2 + Int4x2(7, -8), // 8 becomes -8 + Int4x2(3, 6) // unchanged }; // WHEN, THEN @@ -722,14 +739,14 @@ TEST(CastOpTest, UInt4x2ToInt4x2) { TEST(CastOpTest, Int8ToInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector int8_input = {-10, 15, 0, -1, 3, -5, 6, 2}; + const std::vector int8_input = {-10, 15, 0, -1, 3, -5, -128, 127}; // values outside int4 range get clamped const std::vector expected_int4x2_output = { Int4x2(-8, 7), // -10 clamped to -8, 15 clamped to 7 Int4x2(0, -1), Int4x2(3, -5), - Int4x2(6, 2)}; + Int4x2(-8, 7)}; // WHEN, THEN TestCastOp(gsl::make_span(int8_input), gsl::make_span(expected_int4x2_output), shape); From 82b1fb5c7de4f68d8ef9cea2f661c2a9dbcb300b Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 7 Jul 2025 16:51:31 -0700 Subject: [PATCH 63/88] Update cast down implementation and tests --- .../core/providers/cpu/tensor/cast_op.cc | 182 +++++----- .../test/providers/cpu/tensor/cast_op_test.cc | 321 +++++++++++------- 2 files changed, 311 insertions(+), 192 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc 
b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index d01f882f82236..2ade7c370cd1f 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -192,71 +192,104 @@ constexpr unsigned int UINT4_MIN = 0; constexpr unsigned int UINT4_MAX = 15; // Helper struct for converting from Int4x2/UInt4x2 elements to any destination type -template +template struct Int4ElementConverter { - static DstType Convert(int8_t val) { - if constexpr (IsOrtFloat16Type::value) { - return DstType(static_cast(val)); - } else if constexpr (IsOrtFloat8Type::value) { - return DstType(static_cast(val), true); - } else { - return static_cast(val); - } - } -}; - -// Helper struct for converting from any type to Int4/UInt4 elements -template -struct ToInt4ElementConverter { - // Default implementation for most numeric types static int8_t ConvertToInt4(const SrcType& val) { - int result = static_cast(val); - // Clamp to int4 range (-8 to 7) - return static_cast(std::clamp(result, INT4_MIN, INT4_MAX)); + // Truncate to 4 bits and sign-extend properly + uint8_t truncated = static_cast(val) & 0x0F; + // Sign-extend: if bit 3 is set, it's negative in 4-bit two's complement + return static_cast((truncated & 0x8) ? (truncated | 0xF0) : truncated); } static uint8_t ConvertToUInt4(const SrcType& val) { - unsigned int result = static_cast(val); - // Clamp to uint4 range (0 to 15) - return static_cast(std::min(result, UINT4_MAX)); + // Truncate to 4 bits + return static_cast(val) & 0x0F; + } + + static SrcType Convert(int8_t val) { + if constexpr (IsOrtFloat16Type::value) { + return SrcType(static_cast(val)); + } else if constexpr (IsOrtFloat8Type::value) { + return SrcType(static_cast(val), true); + } else { + return static_cast(val); + } } }; template <> -struct ToInt4ElementConverter { +struct Int4ElementConverter { static int8_t ConvertToInt4(const float& val) { int result = static_cast(std::roundf(val)); - return static_cast(std::clamp(result, INT4_MIN, INT4_MAX)); + uint8_t truncated = static_cast(result) & 0x0F; + return static_cast((truncated & 0x8) ? (truncated | 0xF0) : truncated); } static uint8_t ConvertToUInt4(const float& val) { - unsigned int result = static_cast(std::max(0.0f, std::roundf(val))); - return static_cast(std::min(result, UINT4_MAX)); + int result = static_cast(std::roundf(val)); + return static_cast(result) & 0x0F; + } + + static float Convert(int8_t val) { + return static_cast(val); } }; template <> -struct ToInt4ElementConverter { +struct Int4ElementConverter { static int8_t ConvertToInt4(const double& val) { int result = static_cast(std::round(val)); - return static_cast(std::clamp(result, INT4_MIN, INT4_MAX)); + uint8_t truncated = static_cast(result) & 0x0F; + return static_cast((truncated & 0x8) ? 
(truncated | 0xF0) : truncated); } static uint8_t ConvertToUInt4(const double& val) { - unsigned int result = static_cast(std::max(0.0, std::round(val))); - return static_cast(std::min(result, UINT4_MAX)); + int result = static_cast(std::round(val)); + return static_cast(result) & 0x0F; + } + + static double Convert(int8_t val) { + return static_cast(val); } }; -template -struct ToInt4ElementConverter::value || IsOrtFloat8Type::value>> { - static int8_t ConvertToInt4(const SrcType& val) { - return ToInt4ElementConverter::ConvertToInt4(static_cast(val)); +template <> +struct Int4ElementConverter { + static int8_t ConvertToInt4(const MLFloat16& val) { + float f_val = static_cast(val); + int result = static_cast(std::roundf(f_val)); + uint8_t truncated = static_cast(result) & 0x0F; + return static_cast((truncated & 0x8) ? (truncated | 0xF0) : truncated); } - static uint8_t ConvertToUInt4(const SrcType& val) { - return ToInt4ElementConverter::ConvertToUInt4(static_cast(val)); + static uint8_t ConvertToUInt4(const MLFloat16& val) { + float f_val = static_cast(val); + int result = static_cast(std::roundf(f_val)); + return static_cast(result) & 0x0F; + } + + static MLFloat16 Convert(int8_t val) { + return MLFloat16(static_cast(val)); + } +}; + +template <> +struct Int4ElementConverter { + static int8_t ConvertToInt4(const BFloat16& val) { + float f_val = static_cast(val); + int result = static_cast(std::roundf(f_val)); + uint8_t truncated = static_cast(result) & 0x0F; + return static_cast((truncated & 0x8) ? (truncated | 0xF0) : truncated); + } + + static uint8_t ConvertToUInt4(const BFloat16& val) { + float f_val = static_cast(val); + int result = static_cast(std::roundf(f_val)); + return static_cast(result) & 0x0F; + } + + static BFloat16 Convert(int8_t val) { + return BFloat16(static_cast(val)); } }; @@ -346,28 +379,26 @@ struct TensorCaster { auto* out_data = out.MutableData(); // Every 2 strings combine into 1 Int4x2 - ptrdiff_t i = 0; - for (; i < shape_size - 1; i += 2) { - // Parse each string and clamp to int4 range (-8 to 7) - int v0 = std::stoi(in_data[i]); - int v1 = std::stoi(in_data[i + 1]); - int8_t val0 = static_cast(std::clamp(v0, INT4_MIN, INT4_MAX)); - int8_t val1 = static_cast(std::clamp(v1, INT4_MIN, INT4_MAX)); - - out_data[i >> 1] = Int4x2(val0, val1); - } - - // Handle odd number of elements - pad with 0 - if (i < shape_size) { - int v0 = std::stoi(in_data[i]); - int8_t val0 = static_cast(std::clamp(v0, INT4_MIN, INT4_MAX)); + const ptrdiff_t out_size = (shape_size + 1) >> 1; + for (ptrdiff_t i = 0; i < out_size; ++i) { + const ptrdiff_t in_idx = i << 1; + + // Parse first value and truncate to lower 4 bits with sign extension + int v0 = std::stoi(in_data[in_idx]); + int8_t val0 = static_cast((v0 & 0xF) | (-(v0 & 0x8) & 0xF0)); + + // Parse second value (or use 0 if odd number of elements) + int8_t val1 = 0; + if (in_idx + 1 < shape_size) { + int v1 = std::stoi(in_data[in_idx + 1]); + val1 = static_cast((v1 & 0xF) | (-(v1 & 0x8) & 0xF0)); + } - out_data[i >> 1] = Int4x2(val0, 0); + out_data[i] = Int4x2(val0, val1); } } }; -// TensorCaster specialization for string to UInt4x2 template <> struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { @@ -376,23 +407,22 @@ struct TensorCaster { auto* out_data = out.MutableData(); // Every 2 strings combine into 1 UInt4x2 - ptrdiff_t i = 0; - for (; i < shape_size - 1; i += 2) { - // Parse each string and clamp to uint4 range (0 to 15) - int v0 = std::stoi(in_data[i]); 
- int v1 = std::stoi(in_data[i + 1]); - uint8_t val0 = static_cast(std::clamp(v0, static_cast(UINT4_MIN), static_cast(UINT4_MAX))); - uint8_t val1 = static_cast(std::clamp(v1, static_cast(UINT4_MIN), static_cast(UINT4_MAX))); - - out_data[i >> 1] = UInt4x2(val0, val1); - } - - // Handle odd number of elements - pad with 0 - if (i < shape_size) { - int v0 = std::stoi(in_data[i]); - uint8_t val0 = static_cast(std::clamp(v0, static_cast(UINT4_MIN), static_cast(UINT4_MAX))); + const ptrdiff_t out_size = (shape_size + 1) >> 1; + for (ptrdiff_t i = 0; i < out_size; ++i) { + const ptrdiff_t in_idx = i << 1; + + // Parse first value and truncate to lower 4 bits + int v0 = std::stoi(in_data[in_idx]); + uint8_t val0 = static_cast(v0 & 0xF); + + // Parse second value (or use 0 if odd number of elements) + uint8_t val1 = 0; + if (in_idx + 1 < shape_size) { + int v1 = std::stoi(in_data[in_idx + 1]); + val1 = static_cast(v1 & 0xF); + } - out_data[i >> 1] = UInt4x2(val0, 0); + out_data[i] = UInt4x2(val0, val1); } } }; @@ -541,14 +571,14 @@ struct TensorCaster::ConvertToInt4(in_data[i]); - int8_t high_val = ToInt4ElementConverter::ConvertToInt4(in_data[i + 1]); + int8_t low_val = Int4ElementConverter::ConvertToInt4(in_data[i]); + int8_t high_val = Int4ElementConverter::ConvertToInt4(in_data[i + 1]); out_data[i >> 1] = Int4x2(low_val, high_val); } if (i < shape_size) { - int8_t low_val = ToInt4ElementConverter::ConvertToInt4(in_data[i]); + int8_t low_val = Int4ElementConverter::ConvertToInt4(in_data[i]); out_data[i >> 1] = Int4x2(low_val, 0); } @@ -588,14 +618,14 @@ struct TensorCaster::ConvertToUInt4(in_data[i]); - uint8_t high_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i + 1]); + uint8_t low_val = Int4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t high_val = Int4ElementConverter::ConvertToUInt4(in_data[i + 1]); out_data[i >> 1] = UInt4x2(low_val, high_val); } if (i < shape_size) { - uint8_t low_val = ToInt4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t low_val = Int4ElementConverter::ConvertToUInt4(in_data[i]); out_data[i >> 1] = UInt4x2(low_val, 0); } diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 5720446e850c0..694666f98e3e6 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -300,7 +300,7 @@ TEST(CastOpTest, Int4x2ToUInt32) { }; // Negative values will be cast to their unsigned representation - const std::vector expected_uint32_output = {4294967288, 7, 0, 4294967295, 3, 4294967291, 6, 2}; + const std::vector expected_uint32_output = {4294967288, 7, 0, UINT32_MAX, 3, 4294967291, 6, 2}; // WHEN, THEN TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint32_output), shape); @@ -348,7 +348,7 @@ TEST(CastOpTest, Int4x2ToUInt64) { }; // Negative values will be cast to their unsigned representation - const std::vector expected_uint32_output = {18446744073709551608, 7, 0, 18446744073709551615, 3, 18446744073709551611, 6, 2}; + const std::vector expected_uint32_output = {18446744073709551608, 7, 0, UINT64_MAX, 3, 18446744073709551611, 6, 2}; // WHEN, THEN TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint32_output), shape); @@ -705,10 +705,10 @@ TEST(CastOpTest, Int4x2ToUInt4x2) { }; const std::vector expected_uint4x2_output = { - UInt4x2(8, 7), // -8 becomes 8 - UInt4x2(0, 15), // -1 becomes 15 - UInt4x2(3, 11), // -5 becomes 11 - UInt4x2(6, 2) // unchanged + UInt4x2(8, 7), // -8 becomes 8 + 
UInt4x2(0, 15), // -1 becomes 15 + UInt4x2(3, 11), // -5 becomes 11 + UInt4x2(6, 2) // unchanged }; // WHEN, THEN @@ -741,12 +741,12 @@ TEST(CastOpTest, Int8ToInt4x2) { const std::vector shape{2, 2, 2}; const std::vector int8_input = {-10, 15, 0, -1, 3, -5, -128, 127}; - // values outside int4 range get clamped const std::vector expected_int4x2_output = { - Int4x2(-8, 7), // -10 clamped to -8, 15 clamped to 7 - Int4x2(0, -1), - Int4x2(3, -5), - Int4x2(-8, 7)}; + Int4x2(6, -1), // -10 truncated to 6, 15 truncated to -1 + Int4x2(0, -1), // 0 unchanged, -1 unchanged + Int4x2(3, -5), // 3 unchanged, -5 unchanged + Int4x2(0, -1) // -128 truncated to 0, 127 truncated to -1 + }; // WHEN, THEN TestCastOp(gsl::make_span(int8_input), gsl::make_span(expected_int4x2_output), shape); @@ -755,14 +755,15 @@ TEST(CastOpTest, Int8ToInt4x2) { TEST(CastOpTest, UInt8ToUInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector uint8_input = {20, 15, 0, 1, 7, 25, 3, 12}; + const std::vector uint8_input = {20, 255, 0, 17, 7, 240, 15, 31}; - // values outside uint4 range get clamped + // values get truncated to lower 4 bits const std::vector expected_uint4x2_output = { - UInt4x2(15, 15), // 20 clamped to 15 - UInt4x2(0, 1), - UInt4x2(7, 15), // 25 clamped to 15 - UInt4x2(3, 12)}; + UInt4x2(4, 15), // 20 (0x14) truncated to 4, 255 (0xFF) truncated to 15 + UInt4x2(0, 1), // 0 (0x00) truncated to 0, 17 (0x11) truncated to 1 + UInt4x2(7, 0), // 7 (0x07) truncated to 7, 240 (0xF0) truncated to 0 + UInt4x2(15, 15) // 15 (0x0F) truncated to 15, 31 (0x1F) truncated to 15 + }; // WHEN, THEN TestCastOp(gsl::make_span(uint8_input), gsl::make_span(expected_uint4x2_output), shape); @@ -771,14 +772,15 @@ TEST(CastOpTest, UInt8ToUInt4x2) { TEST(CastOpTest, Int16ToInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector int16_input = {-10, 15, 0, -1, 3, -5, 6, 2}; + const std::vector int16_input = {-10, 32767, 0, -32768, 3, -5, 240, 31}; - // values outside int4 range get clamped + // values get truncated to lower 4 bits and sign-extended const std::vector expected_int4x2_output = { - Int4x2(-8, 7), // -10 clamped to -8, 15 clamped to 7 - Int4x2(0, -1), - Int4x2(3, -5), - Int4x2(6, 2)}; + Int4x2(6, -1), // -10 (0xFFF6) truncated to 6, 32767 (0x7FFF) truncated to -1 + Int4x2(0, 0), // 0 (0x0000) truncated to 0, -32768 (0x8000) truncated to 0 + Int4x2(3, -5), // 3 (0x0003) truncated to 3, -5 (0xFFFB) truncated to -5 + Int4x2(0, -1) // 240 (0x00F0) truncated to 0, 31 (0x001F) truncated to -1 + }; // WHEN, THEN TestCastOp(gsl::make_span(int16_input), gsl::make_span(expected_int4x2_output), shape); @@ -787,14 +789,15 @@ TEST(CastOpTest, Int16ToInt4x2) { TEST(CastOpTest, UInt16ToUInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector uint16_input = {20, 15, 0, 1, 7, 25, 3, 12}; + const std::vector uint16_input = {20, 65535, 0, 256, 7, 240, 15, 4095}; - // values outside uint4 range get clamped + // values get truncated to lower 4 bits const std::vector expected_uint4x2_output = { - UInt4x2(15, 15), // 20 clamped to 15 - UInt4x2(0, 1), - UInt4x2(7, 15), // 25 clamped to 15 - UInt4x2(3, 12)}; + UInt4x2(4, 15), // 20 (0x0014) truncated to 4, 65535 (0xFFFF) truncated to 15 + UInt4x2(0, 0), // 0 (0x0000) truncated to 0, 256 (0x0100) truncated to 0 + UInt4x2(7, 0), // 7 (0x0007) truncated to 7, 240 (0x00F0) truncated to 0 + UInt4x2(15, 15) // 15 (0x000F) truncated to 15, 4095 (0x0FFF) truncated to 15 + }; // WHEN, THEN TestCastOp(gsl::make_span(uint16_input), 
gsl::make_span(expected_uint4x2_output), shape); @@ -803,14 +806,15 @@ TEST(CastOpTest, UInt16ToUInt4x2) { TEST(CastOpTest, Int32ToInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector int32_input = {-10, 15, 0, -1, 3, -5, 6, 2}; + const std::vector int32_input = {-10, INT32_MAX, 0, INT32_MIN, 3, -5, 4080, 287}; - // values outside int4 range get clamped + // values get truncated to lower 4 bits and sign-extended const std::vector expected_int4x2_output = { - Int4x2(-8, 7), // -10 clamped to -8, 15 clamped to 7 - Int4x2(0, -1), - Int4x2(3, -5), - Int4x2(6, 2)}; + Int4x2(6, -1), // -10 (0xFFFFFFF6) truncated to 6, 2147483647 (0x7FFFFFFF) truncated to -1 + Int4x2(0, 0), // 0 (0x00000000) truncated to 0, -2147483648 (0x80000000) truncated to 0 + Int4x2(3, -5), // 3 (0x00000003) truncated to 3, -5 (0xFFFFFFFB) truncated to -5 + Int4x2(0, -1) // 4080 (0x00000FF0) truncated to 0, 287 (0x0000011F) truncated to -1 + }; // WHEN, THEN TestCastOp(gsl::make_span(int32_input), gsl::make_span(expected_int4x2_output), shape); @@ -819,12 +823,12 @@ TEST(CastOpTest, Int32ToInt4x2) { TEST(CastOpTest, Int32ToInt4x2OddNumberOfElements) { // GIVEN const std::vector odd_shape{5}; - const std::vector odd_input = {-8, 7, 0, -1, 3}; + const std::vector odd_input = {-10, INT32_MAX, 0, INT32_MIN, 4095}; const std::vector expected_odd_output = { - Int4x2(-8, 7), - Int4x2(0, -1), - Int4x2(3, 0) // last element paired with 0 + Int4x2(6, -1), // -10 truncated to 6, 2147483647 truncated to -1 + Int4x2(0, 0), // 0 truncated to 0, -2147483648 truncated to 0 + Int4x2(-1, 0) // 4095 truncated to -1, paired with 0 }; // WHEN, THEN @@ -844,14 +848,15 @@ TEST(CastOpTest, Int32ToInt4x2EmptyTensor) { TEST(CastOpTest, UInt32ToUInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector uint32_input = {20, 15, 0, 1, 7, 25, 3, 12}; + const std::vector uint32_input = {20, UINT32_MAX, 0, 256, 7, 240, 15, 4095}; - // values outside uint4 range get clamped + // values get truncated to lower 4 bits const std::vector expected_uint4x2_output = { - UInt4x2(15, 15), // 20 clamped to 15 - UInt4x2(0, 1), - UInt4x2(7, 15), // 25 clamped to 15 - UInt4x2(3, 12)}; + UInt4x2(4, 15), // 20 truncated to 4, 4294967295 truncated to 15 + UInt4x2(0, 0), // 0 truncated to 0, 256 truncated to 0 + UInt4x2(7, 0), // 7 truncated to 7, 240 truncated to 0 + UInt4x2(15, 15) // 15 truncated to 15, 4095 truncated to 15 + }; // WHEN, THEN TestCastOp(gsl::make_span(uint32_input), gsl::make_span(expected_uint4x2_output), shape); @@ -860,14 +865,15 @@ TEST(CastOpTest, UInt32ToUInt4x2) { TEST(CastOpTest, Int64ToInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector int64_input = {-10, 15, 0, -1, 3, -5, 6, 2}; + const std::vector int64_input = {-10, INT64_MAX, 0, INT64_MIN, 3, -5, 65520, 4111}; - // values outside int4 range get clamped + // values get truncated to lower 4 bits and sign-extended const std::vector expected_int4x2_output = { - Int4x2(-8, 7), // -10 clamped to -8, 15 clamped to 7 - Int4x2(0, -1), - Int4x2(3, -5), - Int4x2(6, 2)}; + Int4x2(6, -1), // -10 truncated to 6, 9223372036854775807 truncated to -1 + Int4x2(0, 0), // 0 truncated to 0, -9223372036854775808 truncated to 0 + Int4x2(3, -5), // 3 truncated to 3, -5 truncated to -5 + Int4x2(0, -1) // 65520 truncated to 0, 4111 truncated to -1 + }; // WHEN, THEN TestCastOp(gsl::make_span(int64_input), gsl::make_span(expected_int4x2_output), shape); @@ -876,14 +882,15 @@ TEST(CastOpTest, Int64ToInt4x2) { TEST(CastOpTest, UInt64ToUInt4x2) { // GIVEN const 
std::vector shape{2, 2, 2}; - const std::vector uint64_input = {20, 15, 0, 1, 7, 25, 3, 12}; + const std::vector uint64_input = {20, UINT64_MAX, 0, 256, 7, 240, 15, 4095}; - // values outside uint4 range get clamped + // values get truncated to lower 4 bits const std::vector expected_uint4x2_output = { - UInt4x2(15, 15), // 20 clamped to 15 - UInt4x2(0, 1), - UInt4x2(7, 15), // 25 clamped to 15 - UInt4x2(3, 12)}; + UInt4x2(4, 15), // 20 truncated to 4, 18446744073709551615 truncated to 15 + UInt4x2(0, 0), // 0 truncated to 0, 256 truncated to 0 + UInt4x2(7, 0), // 7 truncated to 7, 240 truncated to 0 + UInt4x2(15, 15) // 15 truncated to 15, 4095 truncated to 15 + }; // WHEN, THEN TestCastOp(gsl::make_span(uint64_input), gsl::make_span(expected_uint4x2_output), shape); @@ -892,13 +899,13 @@ TEST(CastOpTest, UInt64ToUInt4x2) { TEST(CastOpTest, FloatToInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector float_input = {-10.7f, 15.3f, 0.4f, -1.6f, 3.8f, -5.2f, 6.1f, 2.9f}; + const std::vector float_input = {-10.7f, 15.3f, 0.4f, -1.6f, 3.8f, -5.2f, 240.1f, 31.9f}; const std::vector expected_int4x2_output = { - Int4x2(-8, 7), // -10.7 rounded to -11, clamped to -8; 15.3 rounded to 15, clamped to 7 - Int4x2(0, -2), // 0.4 rounded to 0; -1.6 rounded to -2 - Int4x2(4, -5), // 3.8 rounded to 4; -5.2 rounded to -5 - Int4x2(6, 3) // 6.1 rounded to 6; 2.9 rounded to 3 + Int4x2(5, -1), // -10.7 rounded to -11 (0xF5), truncated to 5, sign-extended to 5; 15.3 rounded to 15 (0x0F), sign-extended to -1 + Int4x2(0, -2), // 0.4 rounded to 0; -1.6 rounded to -2 (0xFE), truncated to 14 (0x0E), sign-extended to -2 + Int4x2(4, -5), // 3.8 rounded to 4; -5.2 rounded to -5 (0xFB), truncated to 11 (0x0B), sign-extended to -5 + Int4x2(0, 0) // 240.1 rounded to 240 (0xF0), truncated to 0; 31.9 rounded to 32 (0x20), truncated to 0 }; // WHEN, THEN @@ -908,13 +915,13 @@ TEST(CastOpTest, FloatToInt4x2) { TEST(CastOpTest, DoubleToUInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector double_input = {20.7, 15.3, 0.4, 1.6, 7.8, 25.2, 3.1, 12.9}; + const std::vector double_input = {20.7, 255.3, 0.4, 1.6, 7.8, 240.2, 15.1, 31.9}; const std::vector expected_uint4x2_output = { - UInt4x2(15, 15), // 20.7 rounded to 21, clamped to 15; 15.3 rounded to 15 - UInt4x2(0, 2), // 0.4 rounded to 0; 1.6 rounded to 2 - UInt4x2(8, 15), // 7.8 rounded to 8; 25.2 rounded to 25, clamped to 15 - UInt4x2(3, 13) // 3.1 rounded to 3; 12.9 rounded to 13 + UInt4x2(5, 15), // 20.7 rounded to 21, truncated to 5; 255.3 rounded to 255, truncated to 15 + UInt4x2(0, 2), // 0.4 rounded to 0; 1.6 rounded to 2 + UInt4x2(8, 0), // 7.8 rounded to 8; 240.2 rounded to 240, truncated to 0 + UInt4x2(15, 0) // 15.1 rounded to 15; 31.9 rounded to 32, truncated to 0 }; // WHEN, THEN @@ -925,20 +932,21 @@ TEST(CastOpTest, MLFloat16ToInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; const MLFloat16 mlfloat16_array[8] = { - MLFloat16(static_cast(-8)), - MLFloat16(static_cast(7)), - MLFloat16(static_cast(0)), - MLFloat16(static_cast(-1)), - MLFloat16(static_cast(3)), - MLFloat16(static_cast(-5)), - MLFloat16(static_cast(6)), - MLFloat16(static_cast(2))}; + MLFloat16(static_cast(-10.7f)), + MLFloat16(static_cast(15.3f)), + MLFloat16(static_cast(0.4f)), + MLFloat16(static_cast(-1.6f)), + MLFloat16(static_cast(3.8f)), + MLFloat16(static_cast(-5.2f)), + MLFloat16(static_cast(240.1f)), + MLFloat16(static_cast(31.9f))}; const std::vector expected_int4x2 = { - Int4x2(-8, 7), - Int4x2(0, -1), - Int4x2(3, -5), - Int4x2(6, 2)}; + Int4x2(5, 
-1), // -10.7 rounded to -11 (0xF5), truncated to 5; 15.3 rounded to 15 (0x0F), sign-extended to -1 + Int4x2(0, -2), // 0.4 rounded to 0; -1.6 rounded to -2 (0xFE), truncated to 14 (0x0E), sign-extended to -2 + Int4x2(4, -5), // 3.8 rounded to 4; -5.2 rounded to -5 (0xFB), truncated to 11 (0x0B), sign-extended to -5 + Int4x2(0, 0) // 240.1 rounded to 240 (0xF0), truncated to 0; 31.9 rounded to 32 (0x20), truncated to 0 + }; // WHEN, THEN TestCastOp( @@ -952,22 +960,23 @@ TEST(CastOpTest, MLFloat16ToUInt4x2) { // 8 MLFloat16 values will compress to 4 UInt4x2 values const std::vector shape{2, 4}; // Shape that contains 8 elements - // MLFloat16 values: 0, 15, 8, 7, 3, 12, 10, 5 + // MLFloat16 values with edge cases and truncation scenarios const MLFloat16 mlfloat16_array[8] = { - MLFloat16(static_cast(0)), - MLFloat16(static_cast(15)), - MLFloat16(static_cast(8)), - MLFloat16(static_cast(7)), - MLFloat16(static_cast(3)), - MLFloat16(static_cast(12)), - MLFloat16(static_cast(10)), - MLFloat16(static_cast(5))}; + MLFloat16(static_cast(20.7f)), + MLFloat16(static_cast(255.3f)), + MLFloat16(static_cast(0.4f)), + MLFloat16(static_cast(1.6f)), + MLFloat16(static_cast(7.8f)), + MLFloat16(static_cast(240.2f)), + MLFloat16(static_cast(15.1f)), + MLFloat16(static_cast(31.9f))}; const std::vector expected_uint4x2 = { - UInt4x2(0, 15), - UInt4x2(8, 7), - UInt4x2(3, 12), - UInt4x2(10, 5)}; + UInt4x2(5, 15), // 20.7 rounded to 21, truncated to 5; 255.3 rounded to 255, truncated to 15 + UInt4x2(0, 2), // 0.4 rounded to 0; 1.6 rounded to 2 + UInt4x2(8, 0), // 7.8 rounded to 8; 240.2 rounded to 240, truncated to 0 + UInt4x2(15, 0) // 15.1 rounded to 15; 31.9 rounded to 32, truncated to 0 + }; // WHEN, THEN TestCastOp( @@ -976,23 +985,23 @@ TEST(CastOpTest, MLFloat16ToUInt4x2) { shape); } -TEST(CastOpTest, MLFloat16ToInt4x2BoundaryValuesClamping) { +TEST(CastOpTest, MLFloat16ToInt4x2BoundaryValues) { // GIVEN - // Test MLFloat16 values that need clamping to Int4x2 range (-8 to 7) + // Test MLFloat16 values that need truncation to Int4x2 range const std::vector shape{3, 2}; const MLFloat16 mlfloat16_array[6] = { - MLFloat16(static_cast(-10)), // Below min, should clamp to -8 - MLFloat16(static_cast(9)), // Above max, should clamp to 7 - MLFloat16(static_cast(-8)), // At min, should remain -8 - MLFloat16(static_cast(7)), // At max, should remain 7 + MLFloat16(static_cast(-10)), // Truncated to lower 4 bits + MLFloat16(static_cast(9)), // Truncated to lower 4 bits + MLFloat16(static_cast(-8)), // Truncated to lower 4 bits + MLFloat16(static_cast(7)), // Truncated to lower 4 bits MLFloat16(static_cast(-0.6f)), // Should round to -1 MLFloat16(static_cast(1.7f)) // Should round to 2 }; - // Values should be clamped to int4 range (-8 to 7) + // Values get truncated to lower 4 bits and sign-extended const std::vector expected_int4x2 = { - Int4x2(-8, 7), // -10 clamped to -8, 9 clamped to 7 - Int4x2(-8, 7), // -8 and 7 already at boundaries + Int4x2(6, -7), // -10 (0xFFFFFFF6) truncated to 6, 9 (0x00000009) truncated to -7 + Int4x2(-8, 7), // -8 (0xFFFFFFF8) truncated to -8, 7 (0x00000007) truncated to 7 Int4x2(-1, 2) // -0.6 rounds to -1, 1.7 rounds to 2 }; @@ -1003,23 +1012,23 @@ TEST(CastOpTest, MLFloat16ToInt4x2BoundaryValuesClamping) { shape); } -TEST(CastOpTest, MLFloat16ToUInt4x2BoundaryValuesClamping) { +TEST(CastOpTest, MLFloat16ToUInt4x2BoundaryValues) { // GIVEN - // Test MLFloat16 values that need clamping to UInt4x2 range (0 to 15) + // Test MLFloat16 values that need truncation to UInt4x2 range const 
std::vector shape{3, 2}; // Shape that contains 6 elements const MLFloat16 mlfloat16_array[6] = { - MLFloat16(static_cast(-5)), // Negative, should clamp to 0 - MLFloat16(static_cast(20)), // Above max, should clamp to 15 + MLFloat16(static_cast(-5)), // Negative, truncated to lower 4 bits + MLFloat16(static_cast(20)), // Above max, truncated to lower 4 bits MLFloat16(static_cast(0)), // At min, should remain 0 MLFloat16(static_cast(15)), // At max, should remain 15 MLFloat16(static_cast(3.4f)), // Should round to 3 MLFloat16(static_cast(5.7f)) // Should round to 6 }; - // Values should be clamped to uint4 range (0-15) + // Values get truncated to lower 4 bits (no sign extension for unsigned) const std::vector expected_uint4x2 = { - UInt4x2(0, 15), // -5 clamped to 0, 20 clamped to 15 - UInt4x2(0, 15), // 0 and 15 already at boundaries + UInt4x2(11, 4), // -5 (0xFFFFFFFB) truncated to 11, 20 (0x00000014) truncated to 4 + UInt4x2(0, 15), // 0 and 15 already within range UInt4x2(3, 6) // 3.4 rounds to 3, 5.7 rounds to 6 }; @@ -1030,6 +1039,86 @@ TEST(CastOpTest, MLFloat16ToUInt4x2BoundaryValuesClamping) { shape); } +TEST(CastOpTest, BFloat16ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const BFloat16 bfloat16_array[8] = { + BFloat16(static_cast(-10.7f)), + BFloat16(static_cast(15.3f)), + BFloat16(static_cast(0.4f)), + BFloat16(static_cast(-1.6f)), + BFloat16(static_cast(3.8f)), + BFloat16(static_cast(-5.2f)), + BFloat16(static_cast(240.1f)), + BFloat16(static_cast(31.9f))}; + + const std::vector expected_int4x2 = { + Int4x2(5, -1), // -10.7 rounded to -11 (0xF5), truncated to 5; 15.3 rounded to 15 (0x0F), sign-extended to -1 + Int4x2(0, -2), // 0.4 rounded to 0; -1.6 rounded to -2 (0xFE), truncated to 14 (0x0E), sign-extended to -2 + Int4x2(4, -5), // 3.8 rounded to 4; -5.2 rounded to -5 (0xFB), truncated to 11 (0x0B), sign-extended to -5 + Int4x2(0, 0) // 240.1 rounded to 240 (0xF0), truncated to 0; 31.9 rounded to 32 (0x20), truncated to 0 + }; + + // WHEN, THEN + TestCastOp( + gsl::span(bfloat16_array, 8), + gsl::span(expected_int4x2), + shape); +} + +TEST(CastOpTest, BFloat16ToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const BFloat16 bfloat16_array[8] = { + BFloat16(static_cast(20.7f)), + BFloat16(static_cast(255.3f)), + BFloat16(static_cast(0.4f)), + BFloat16(static_cast(1.6f)), + BFloat16(static_cast(7.8f)), + BFloat16(static_cast(240.2f)), + BFloat16(static_cast(15.1f)), + BFloat16(static_cast(31.9f))}; + + const std::vector expected_uint4x2 = { + UInt4x2(5, 15), // 20.7 rounded to 21, truncated to 5; 255.3 rounded to 255, truncated to 15 + UInt4x2(0, 2), // 0.4 rounded to 0; 1.6 rounded to 2 + UInt4x2(8, 0), // 7.8 rounded to 8; 240.2 rounded to 240, truncated to 0 + UInt4x2(15, 0) // 15.1 rounded to 15; 31.9 rounded to 32, truncated to 0 + }; + + // WHEN, THEN + TestCastOp( + gsl::span(bfloat16_array, 8), + gsl::span(expected_uint4x2), + shape); +} + +TEST(CastOpTest, BFloat16ToUInt4x2BoundaryValues) { + // GIVEN + const std::vector shape{3, 2}; + const BFloat16 bfloat16_array[6] = { + BFloat16(static_cast(-5)), // Negative, truncated to lower 4 bits + BFloat16(static_cast(20)), // Above max, truncated to lower 4 bits + BFloat16(static_cast(0)), // At min, should remain 0 + BFloat16(static_cast(15)), // At max, should remain 15 + BFloat16(static_cast(3.4f)), // Should round to 3 + BFloat16(static_cast(5.7f)) // Should round to 6 + }; + + // Values get truncated to lower 4 bits (no clamping for consistency) + const std::vector expected_uint4x2 = { + 
UInt4x2(11, 4), // -5 (0xFFFFFFFB) truncated to 11, 20 (0x00000014) truncated to 4 + UInt4x2(0, 15), // 0 and 15 already within range + UInt4x2(3, 6) // 3.4 rounds to 3, 5.7 rounds to 6 + }; + + // WHEN, THEN + TestCastOp( + gsl::span(bfloat16_array, 6), + gsl::span(expected_uint4x2), + shape); +} + TEST(CastOpTest, BoolToInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; @@ -1102,22 +1191,22 @@ TEST(CastOpTest, StringToUInt4x2) { TestCastOp(gsl::span(string_input), gsl::span(expected_output), shape); } -TEST(CastOpTest, String2UInt4x2BoundaryValuesClamping) { +TEST(CastOpTest, String2UInt4x2BoundaryValues) { // GIVEN - // Test string values that need clamping to UInt4x2 range (0-15) + // Test string values that need truncation to UInt4x2 range const std::vector shape{3, 2}; const std::vector string_input = { - "-5", "20", // out of range values that should be clamped - "16", "100", // out of range values that should be clamped + "-5", "20", // out of range values that get truncated + "16", "100", // out of range values that get truncated "0", "15" // boundary values that are in range }; // Each pair of strings becomes one UInt4x2 - // Values should be clamped to uint4 range (0-15) + // Values get truncated to lower 4 bits (no sign extension for unsigned) const std::vector expected_output{ - UInt4x2(0, 15), // -5 clamped to 0, 20 clamped to 15 - UInt4x2(15, 15), // 16 clamped to 15, 100 clamped to 15 - UInt4x2(0, 15) // 0 and 15 already in range + UInt4x2(11, 4), // -5 (0xFFFFFFFB) truncated to 11, 20 (0x00000014) truncated to 4 + UInt4x2(0, 4), // 16 (0x00000010) truncated to 0, 100 (0x00000064) truncated to 4 + UInt4x2(0, 15) // 0 and 15 already in range }; // WHEN, THEN From f1f4e2e8b2a2db9da965b895fcad549a64757078 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 7 Jul 2025 18:31:30 -0700 Subject: [PATCH 64/88] Fix pipeline issue --- onnxruntime/test/providers/cpu/tensor/cast_op_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 694666f98e3e6..7a5faa386be6d 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -348,10 +348,10 @@ TEST(CastOpTest, Int4x2ToUInt64) { }; // Negative values will be cast to their unsigned representation - const std::vector expected_uint32_output = {18446744073709551608, 7, 0, UINT64_MAX, 3, 18446744073709551611, 6, 2}; + const std::vector expected_uint64_output = {18446744073709551608ULL, 7, 0, UINT64_MAX, 3, 18446744073709551611ULL, 6, 2}; // WHEN, THEN - TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint32_output), shape); + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint64_output), shape); } TEST(CastOpTest, UInt4x2ToUInt8) { From 2b4c325c1658e02ccab173a30715da7b6da1936d Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 7 Jul 2025 18:57:37 -0700 Subject: [PATCH 65/88] fix pipeline issue --- onnxruntime/core/providers/cpu/tensor/cast_op.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 2ade7c370cd1f..9d836dcad93e5 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -186,11 +186,6 @@ struct EigenCastType { using type = Eigen::bfloat16; }; -constexpr int INT4_MIN = -8; -constexpr int INT4_MAX = 7; -constexpr unsigned 
int UINT4_MIN = 0; -constexpr unsigned int UINT4_MAX = 15; - // Helper struct for converting from Int4x2/UInt4x2 elements to any destination type template struct Int4ElementConverter { From 18521cc918483519170603622666fc9378b3088e Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Wed, 9 Jul 2025 16:45:35 -0700 Subject: [PATCH 66/88] debug pipeline issue --- .../test/providers/cpu/tensor/cast_op_test.cc | 38 +++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 7a5faa386be6d..6ea32bd6df42d 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include @@ -1355,13 +1355,34 @@ TEST(CastOpTest, UInt4x2ToFloat8E5M2) { TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8e5m2_output), shape); } -TEST(CastOpTest, Float8E4M3FNToInt4x2) { +TEST(CastOpTest, Float8E4M3FNToInt4x2SaturateTrue) { // GIVEN const std::vector shape{2, 2, 2}; std::vector float8_input; const std::vector input_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; for (float val : input_values) { float8_input.emplace_back(Float8E4M3FN(val, true)); + std::cout << "Float8(" << val << ") saturate=true: " << static_cast(float8_input.back()) << std::endl; + } + + const std::vector expected_int4x2_output = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + // WHEN, THEN + TestCastOp(gsl::make_span(float8_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, Float8E4M3FNToInt4x2SaturateFalse) { + // GIVEN + const std::vector shape{2, 2, 2}; + std::vector float8_input; + const std::vector input_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; + for (float val : input_values) { + float8_input.emplace_back(Float8E4M3FN(val, false)); + std::cout << "Float8(" << val << ") saturate=false: " << static_cast(float8_input.back()) << std::endl; } const std::vector expected_int4x2_output = { @@ -1380,7 +1401,8 @@ TEST(CastOpTest, Float8E4M3FNToUInt4x2) { std::vector uint_float8_input; const std::vector uint_input_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}; for (float val : uint_input_values) { - uint_float8_input.emplace_back(Float8E4M3FN(val, true)); + uint_float8_input.emplace_back(Float8E4M3FN(val, false)); + std::cout << "Float8(" << val << ") saturate=false: " << static_cast(uint_float8_input.back()) << std::endl; } const std::vector expected_uint4x2_output = { @@ -1393,6 +1415,16 @@ TEST(CastOpTest, Float8E4M3FNToUInt4x2) { TestCastOp(gsl::make_span(uint_float8_input), gsl::make_span(expected_uint4x2_output), shape); } +TEST(CastOpTest, DummyTestFloat8E4M3FNSaturation) { + Float8E4M3FN x_saturate(-8.0f, true); + std::cout << "Float8(-8.0f) saturate=true: " << static_cast(x_saturate) << std::endl; + EXPECT_EQ(static_cast(x_saturate), -8.0f); + + Float8E4M3FN x_no_saturate(-8.0f, false); + std::cout << "Float8(-8.0f) saturate=false: " << static_cast(x_no_saturate) << std::endl; + EXPECT_EQ(static_cast(x_no_saturate), -8.0f); +} + #endif } // namespace test From de3c01be0f758846b7b701c82b5256dc67a2ead2 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Wed, 9 Jul 2025 16:45:42 -0700 Subject: [PATCH 67/88] rename --- 
.../core/providers/cpu/tensor/cast_op.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 9d836dcad93e5..09ec00cf760f9 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -187,27 +187,27 @@ struct EigenCastType { }; // Helper struct for converting from Int4x2/UInt4x2 elements to any destination type -template +template struct Int4ElementConverter { - static int8_t ConvertToInt4(const SrcType& val) { + static int8_t ConvertToInt4(const OtherType& val) { // Truncate to 4 bits and sign-extend properly uint8_t truncated = static_cast(val) & 0x0F; // Sign-extend: if bit 3 is set, it's negative in 4-bit two's complement return static_cast((truncated & 0x8) ? (truncated | 0xF0) : truncated); } - static uint8_t ConvertToUInt4(const SrcType& val) { + static uint8_t ConvertToUInt4(const OtherType& val) { // Truncate to 4 bits return static_cast(val) & 0x0F; } - static SrcType Convert(int8_t val) { - if constexpr (IsOrtFloat16Type::value) { - return SrcType(static_cast(val)); - } else if constexpr (IsOrtFloat8Type::value) { - return SrcType(static_cast(val), true); + static OtherType Convert(int8_t val) { + if constexpr (IsOrtFloat16Type::value) { + return OtherType(static_cast(val)); + } else if constexpr (IsOrtFloat8Type::value) { + return OtherType(static_cast(val), true); } else { - return static_cast(val); + return static_cast(val); } } }; From 0dfb771eb9409d59ab2fb2a28743bec3e87b3bb9 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Wed, 9 Jul 2025 18:13:24 -0700 Subject: [PATCH 68/88] remove debugging tests --- .../test/providers/cpu/tensor/cast_op_test.cc | 69 ------------------- 1 file changed, 69 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 6ea32bd6df42d..2ab2cbc1d52aa 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -1355,75 +1355,6 @@ TEST(CastOpTest, UInt4x2ToFloat8E5M2) { TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8e5m2_output), shape); } -TEST(CastOpTest, Float8E4M3FNToInt4x2SaturateTrue) { - // GIVEN - const std::vector shape{2, 2, 2}; - std::vector float8_input; - const std::vector input_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; - for (float val : input_values) { - float8_input.emplace_back(Float8E4M3FN(val, true)); - std::cout << "Float8(" << val << ") saturate=true: " << static_cast(float8_input.back()) << std::endl; - } - - const std::vector expected_int4x2_output = { - Int4x2(-8, 7), - Int4x2(0, -1), - Int4x2(3, -5), - Int4x2(6, 2)}; - - // WHEN, THEN - TestCastOp(gsl::make_span(float8_input), gsl::make_span(expected_int4x2_output), shape); -} - -TEST(CastOpTest, Float8E4M3FNToInt4x2SaturateFalse) { - // GIVEN - const std::vector shape{2, 2, 2}; - std::vector float8_input; - const std::vector input_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; - for (float val : input_values) { - float8_input.emplace_back(Float8E4M3FN(val, false)); - std::cout << "Float8(" << val << ") saturate=false: " << static_cast(float8_input.back()) << std::endl; - } - - const std::vector expected_int4x2_output = { - Int4x2(-8, 7), - Int4x2(0, -1), - Int4x2(3, -5), - Int4x2(6, 2)}; - - // WHEN, THEN - TestCastOp(gsl::make_span(float8_input), 
gsl::make_span(expected_int4x2_output), shape); -} - -TEST(CastOpTest, Float8E4M3FNToUInt4x2) { - // GIVEN - const std::vector shape{2, 2, 2}; - std::vector uint_float8_input; - const std::vector uint_input_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}; - for (float val : uint_input_values) { - uint_float8_input.emplace_back(Float8E4M3FN(val, false)); - std::cout << "Float8(" << val << ") saturate=false: " << static_cast(uint_float8_input.back()) << std::endl; - } - - const std::vector expected_uint4x2_output = { - UInt4x2(0, 15), - UInt4x2(1, 14), - UInt4x2(7, 8), - UInt4x2(3, 12)}; - - // WHEN, THEN - TestCastOp(gsl::make_span(uint_float8_input), gsl::make_span(expected_uint4x2_output), shape); -} - -TEST(CastOpTest, DummyTestFloat8E4M3FNSaturation) { - Float8E4M3FN x_saturate(-8.0f, true); - std::cout << "Float8(-8.0f) saturate=true: " << static_cast(x_saturate) << std::endl; - EXPECT_EQ(static_cast(x_saturate), -8.0f); - - Float8E4M3FN x_no_saturate(-8.0f, false); - std::cout << "Float8(-8.0f) saturate=false: " << static_cast(x_no_saturate) << std::endl; - EXPECT_EQ(static_cast(x_no_saturate), -8.0f); -} #endif From 338c4408012ffa37e784ba7deaa51410d6bf082a Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Wed, 9 Jul 2025 18:31:24 -0700 Subject: [PATCH 69/88] lint --- onnxruntime/test/providers/cpu/tensor/cast_op_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 2ab2cbc1d52aa..c4b32b417df9f 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -1355,7 +1355,6 @@ TEST(CastOpTest, UInt4x2ToFloat8E5M2) { TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8e5m2_output), shape); } - #endif } // namespace test From 6d76f96f7407dc1fa94f79c47587eec22c52081f Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Sat, 19 Jul 2025 11:35:07 -0700 Subject: [PATCH 70/88] Update TensorCasterNoSat for Int4/UInt4, add tests with saturate = false --- .../core/providers/cpu/tensor/cast_op.cc | 22 +++++++++---------- .../test/providers/cpu/tensor/cast_op_test.cc | 16 ++++++++++++++ 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 09ec00cf760f9..8d7c1ab991ca8 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -690,6 +690,11 @@ struct TensorCaster { #if !defined(DISABLE_FLOAT8_TYPES) +// TensorCasterNoSat is only called when all the below conditions are met (see Cast::Compute implementation): +// - defined(DISABLE_FLOAT8_TYPES) == false +// - saturate_ == false +// - IsOrtFloat8Type::value == true + // tensor X -> float 8 template struct TensorCasterNoSat { @@ -704,17 +709,12 @@ struct TensorCasterNoSat { }; // TensorCasterNoSat should never be instantiated for Int4x2/UInt4x2 -template -struct TensorCasterNoSat { - void Cast(const OpKernelContext&, const TensorShape&, const Tensor&, Tensor&) const { - ORT_THROW("Int4x2 should never use TensorCasterNoSat"); - } -}; - -template -struct TensorCasterNoSat { - void Cast(const OpKernelContext&, const TensorShape&, const Tensor&, Tensor&) const { - ORT_THROW("UInt4x2 should never use TensorCasterNoSat"); +template +struct TensorCasterNoSat || std::is_same_v>> { + void Cast(const OpKernelContext& context, const TensorShape& shape, const Tensor& src, 
Tensor& dst) const { + // Int4x2/UInt4x2 always fit inside any Float8 type, so we can reuse the saturate = true implementation. + TensorCaster{}.Cast(context, shape, src, dst); } }; diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index c4b32b417df9f..c24913d71d72f 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -1292,7 +1292,11 @@ TEST(CastOpTest, Int4x2ToFloat8E4M3FN) { } // WHEN, THEN + // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8_output), shape); + // Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8_output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); } TEST(CastOpTest, UInt4x2ToFloat8E4M3FN) { @@ -1312,7 +1316,11 @@ TEST(CastOpTest, UInt4x2ToFloat8E4M3FN) { } // WHEN, THEN + // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8_output), shape); + // Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8_output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); } TEST(CastOpTest, Int4x2ToFloat8E5M2) { @@ -1332,7 +1340,11 @@ TEST(CastOpTest, Int4x2ToFloat8E5M2) { } // WHEN, THEN + // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8e5m2_output), shape); + // Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8e5m2_output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); } TEST(CastOpTest, UInt4x2ToFloat8E5M2) { @@ -1352,7 +1364,11 @@ TEST(CastOpTest, UInt4x2ToFloat8E5M2) { } // WHEN, THEN + // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8e5m2_output), shape); + // Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8e5m2_output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); } #endif From 6cbeea4905ecc05c8e2271ee87eb46c40e002c7a Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Sat, 19 Jul 2025 13:30:12 -0700 Subject: [PATCH 71/88] Add comments, extract common logic into lambda --- onnxruntime/core/providers/cpu/tensor/cast_op.cc | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 8d7c1ab991ca8..a9cbc7132e832 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -189,6 +189,10 @@ struct EigenCastType { // Helper struct for converting from Int4x2/UInt4x2 elements to any destination type template struct Int4ElementConverter { + // See 
https://onnx.ai/onnx/operators/onnx__Cast.html#summary + // Casting from fixed point to fixed point: when OOR, discard higher bits and + // reinterpret (with respect to two's complement representation for signed types). + // For example, 200 (int16) -> -56 (int8). static int8_t ConvertToInt4(const OtherType& val) { // Truncate to 4 bits and sign-extend properly uint8_t truncated = static_cast(val) & 0x0F; @@ -373,20 +377,23 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); + auto truncateToLower4BitsAndSignExtend = [](auto val) { + return (val & 0xF) | (-(val & 0x8) & 0xF0); + }; + // Every 2 strings combine into 1 Int4x2 const ptrdiff_t out_size = (shape_size + 1) >> 1; for (ptrdiff_t i = 0; i < out_size; ++i) { const ptrdiff_t in_idx = i << 1; - // Parse first value and truncate to lower 4 bits with sign extension int v0 = std::stoi(in_data[in_idx]); - int8_t val0 = static_cast((v0 & 0xF) | (-(v0 & 0x8) & 0xF0)); + int8_t val0 = static_cast(truncateToLower4BitsAndSignExtend(v0)); // Parse second value (or use 0 if odd number of elements) int8_t val1 = 0; if (in_idx + 1 < shape_size) { int v1 = std::stoi(in_data[in_idx + 1]); - val1 = static_cast((v1 & 0xF) | (-(v1 & 0x8) & 0xF0)); + val1 = static_cast(truncateToLower4BitsAndSignExtend(v1)); } out_data[i] = Int4x2(val0, val1); From ed92b066644bfd639f778f25256a9c046c96f113 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Sat, 19 Jul 2025 13:56:58 -0700 Subject: [PATCH 72/88] extract FromInt4Converter into separate struct, rename converter --- .../core/providers/cpu/tensor/cast_op.cc | 75 +++++++------------ 1 file changed, 26 insertions(+), 49 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index a9cbc7132e832..c9bb5311c8879 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -186,9 +186,22 @@ struct EigenCastType { using type = Eigen::bfloat16; }; +template +struct FromInt4Converter { + static DstType Convert(int8_t val) { + if constexpr (IsOrtFloat16Type::value) { + return DstType(static_cast(val)); + } else if constexpr (IsOrtFloat8Type::value) { + return DstType(static_cast(val), true); + } else { + return static_cast(val); + } + } +}; + // Helper struct for converting from Int4x2/UInt4x2 elements to any destination type template -struct Int4ElementConverter { +struct ToInt4Converter { // See https://onnx.ai/onnx/operators/onnx__Cast.html#summary // Casting from fixed point to fixed point: when OOR, discard higher bits and // reinterpret (with respect to two's complement representation for signed types). 
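A minimal standalone sketch of that rule (TruncateToInt4 is a hypothetical helper name used only for illustration, not something this change adds): keep the low nibble, then sign-extend bit 3, which reproduces the ONNX "discard higher bits and reinterpret" behavior for signed 4-bit targets.

#include <cassert>
#include <cstdint>

int8_t TruncateToInt4(int value) {
  // Keep only the 4 least significant bits.
  uint8_t truncated = static_cast<uint8_t>(value) & 0x0F;
  // If bit 3 is set, the value is negative in 4-bit two's complement, so sign-extend.
  return static_cast<int8_t>((truncated & 0x8) ? (truncated | 0xF0) : truncated);
}

int main() {
  assert(TruncateToInt4(200) == -8);  // 200 = 0xC8 -> nibble 0x8 -> -8
  assert(TruncateToInt4(14) == -2);   // 0x0E -> 1110 -> -2
  assert(TruncateToInt4(-10) == 6);   // 0xF6 -> nibble 0x6 -> 6
  assert(TruncateToInt4(7) == 7);     // already in [-8, 7]
  return 0;
}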
@@ -204,20 +217,10 @@ struct Int4ElementConverter { // Truncate to 4 bits return static_cast(val) & 0x0F; } - - static OtherType Convert(int8_t val) { - if constexpr (IsOrtFloat16Type::value) { - return OtherType(static_cast(val)); - } else if constexpr (IsOrtFloat8Type::value) { - return OtherType(static_cast(val), true); - } else { - return static_cast(val); - } - } }; template <> -struct Int4ElementConverter { +struct ToInt4Converter { static int8_t ConvertToInt4(const float& val) { int result = static_cast(std::roundf(val)); uint8_t truncated = static_cast(result) & 0x0F; @@ -228,14 +231,10 @@ struct Int4ElementConverter { int result = static_cast(std::roundf(val)); return static_cast(result) & 0x0F; } - - static float Convert(int8_t val) { - return static_cast(val); - } }; template <> -struct Int4ElementConverter { +struct ToInt4Converter { static int8_t ConvertToInt4(const double& val) { int result = static_cast(std::round(val)); uint8_t truncated = static_cast(result) & 0x0F; @@ -246,14 +245,10 @@ struct Int4ElementConverter { int result = static_cast(std::round(val)); return static_cast(result) & 0x0F; } - - static double Convert(int8_t val) { - return static_cast(val); - } }; template <> -struct Int4ElementConverter { +struct ToInt4Converter { static int8_t ConvertToInt4(const MLFloat16& val) { float f_val = static_cast(val); int result = static_cast(std::roundf(f_val)); @@ -266,14 +261,10 @@ struct Int4ElementConverter { int result = static_cast(std::roundf(f_val)); return static_cast(result) & 0x0F; } - - static MLFloat16 Convert(int8_t val) { - return MLFloat16(static_cast(val)); - } }; template <> -struct Int4ElementConverter { +struct ToInt4Converter { static int8_t ConvertToInt4(const BFloat16& val) { float f_val = static_cast(val); int result = static_cast(std::roundf(f_val)); @@ -286,10 +277,6 @@ struct Int4ElementConverter { int result = static_cast(std::roundf(f_val)); return static_cast(result) & 0x0F; } - - static BFloat16 Convert(int8_t val) { - return BFloat16(static_cast(val)); - } }; // generic tensor X -> Y @@ -462,12 +449,7 @@ struct TensorCaster> 1].GetElem(i & 0x1); - - if constexpr (std::is_floating_point_v) { - out_data[i] = static_cast(val); - } else { - out_data[i] = Int4ElementConverter::Convert(val); - } + out_data[i] = FromInt4Converter::Convert(val); } } }; @@ -518,12 +500,7 @@ struct TensorCaster> 1].GetElem(i & 0x1); - - if constexpr (std::is_floating_point_v) { - out_data[i] = static_cast(val); - } else { - out_data[i] = Int4ElementConverter::Convert(val); - } + out_data[i] = FromInt4Converter::Convert(val); } } }; @@ -573,14 +550,14 @@ struct TensorCaster::ConvertToInt4(in_data[i]); - int8_t high_val = Int4ElementConverter::ConvertToInt4(in_data[i + 1]); + int8_t low_val = ToInt4Converter::ConvertToInt4(in_data[i]); + int8_t high_val = ToInt4Converter::ConvertToInt4(in_data[i + 1]); out_data[i >> 1] = Int4x2(low_val, high_val); } if (i < shape_size) { - int8_t low_val = Int4ElementConverter::ConvertToInt4(in_data[i]); + int8_t low_val = ToInt4Converter::ConvertToInt4(in_data[i]); out_data[i >> 1] = Int4x2(low_val, 0); } @@ -620,14 +597,14 @@ struct TensorCaster::ConvertToUInt4(in_data[i]); - uint8_t high_val = Int4ElementConverter::ConvertToUInt4(in_data[i + 1]); + uint8_t low_val = ToInt4Converter::ConvertToUInt4(in_data[i]); + uint8_t high_val = ToInt4Converter::ConvertToUInt4(in_data[i + 1]); out_data[i >> 1] = UInt4x2(low_val, high_val); } if (i < shape_size) { - uint8_t low_val = Int4ElementConverter::ConvertToUInt4(in_data[i]); + uint8_t 
low_val = ToInt4Converter::ConvertToUInt4(in_data[i]); out_data[i >> 1] = UInt4x2(low_val, 0); } From fad044554b9b5a28ec19baa618ec2a330cc9ac53 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Sat, 19 Jul 2025 15:23:10 -0700 Subject: [PATCH 73/88] add IsOrtInt4Type, use UnpackedType, merge specializations --- .../core/providers/cpu/tensor/cast_op.cc | 114 +++++------------- 1 file changed, 32 insertions(+), 82 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index c9bb5311c8879..fe4e19cc62c06 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -63,6 +63,9 @@ template struct IsOrtFloat8Type : std::false_type {}; #endif +template +using IsOrtInt4Type = boost::mp11::mp_contains, T>; + template struct IsStandardIntegerType { static constexpr bool value = @@ -186,20 +189,25 @@ struct EigenCastType { using type = Eigen::bfloat16; }; -template +// Helper for converting (U)Int4x2 values to any destination type. +template ::value>> struct FromInt4Converter { - static DstType Convert(int8_t val) { + // The input 'val' can be either an int8_t value coming from Int4x2.GetElem(pos), + // or an uint8_t value coming from UInt4x2.GetElem(pos), where pos can be 0 or 1. + static DstType Convert(typename SrcType::UnpackedType val) { if constexpr (IsOrtFloat16Type::value) { return DstType(static_cast(val)); } else if constexpr (IsOrtFloat8Type::value) { return DstType(static_cast(val), true); + } else if constexpr (std::is_same_v) { + return val != 0; } else { return static_cast(val); } } }; -// Helper struct for converting from Int4x2/UInt4x2 elements to any destination type template struct ToInt4Converter { // See https://onnx.ai/onnx/operators/onnx__Cast.html#summary @@ -297,7 +305,8 @@ struct TensorCaster { // tensor X -> string template -struct TensorCaster { +struct TensorCaster::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const std::ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); @@ -321,42 +330,29 @@ struct TensorCaster { } }; -template <> -struct TensorCaster { +// tensor (U)Int4x2 -> string +template +struct TensorCaster::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const std::ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); + const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - // Unpack each Int4x2 into two separate string elements + // Unpack each Int4x2/UInt4x2 into two separate string elements size_t out_idx = 0; for (ptrdiff_t i = 0; i < shape_size; ++i) { // elem 0 is the low nibble, 1 the high nibble auto val = in_data[i >> 1].GetElem(i & 0x1); + // val has type (u)int8_t, so static_cast is required in order for std::to_string + // to interpret val as a number (65 -> "65"), instead of a char (65 -> "A"). 
out_data[out_idx++] = std::to_string(static_cast(val)); } } }; -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const std::ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - // Unpack each UInt4x2 into two separate string elements - size_t out_idx = 0; - for (ptrdiff_t i = 0; i < shape_size; ++i) { - // elem 0 is the low nibble, 1 the high nibble - auto val = in_data[i >> 1].GetElem(i & 0x1); - - out_data[out_idx++] = std::to_string(static_cast(val)); - } - } -}; - +// tensor string -> Int4x2 template <> struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { @@ -388,6 +384,7 @@ struct TensorCaster { } }; +// tensor string -> UInt4x2 template <> struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { @@ -438,34 +435,19 @@ struct TensorCaster { } }; -template -struct TensorCaster::value || IsOrtFloat16Type::value || std::is_floating_point_v || IsOrtFloat8Type::value>> { +template +struct TensorCaster::value && + (std::is_same_v || IsStandardIntegerType::value || std::is_floating_point_v || IsOrtFloat16Type::value || IsOrtFloat8Type::value)>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); + const auto* in_data = in.Data(); auto* out_data = out.MutableData(); for (ptrdiff_t i = 0; i < shape_size; ++i) { // elem 0 is the low nibble, 1 the high nibble auto val = in_data[i >> 1].GetElem(i & 0x1); - out_data[i] = FromInt4Converter::Convert(val); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - for (ptrdiff_t i = 0; i < shape_size; ++i) { - // elem 0 is the low nibble, 1 the high nibble - auto val = in_data[i >> 1].GetElem(i & 0x1); - - out_data[i] = val != 0; + out_data[i] = FromInt4Converter::Convert(val); } } }; @@ -489,38 +471,6 @@ struct TensorCaster { } }; -template -struct TensorCaster::value || IsOrtFloat16Type::value || std::is_floating_point_v || IsOrtFloat8Type::value>> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - for (ptrdiff_t i = 0; i < shape_size; ++i) { - // elem 0 is the low nibble, 1 the high nibble - auto val = in_data[i >> 1].GetElem(i & 0x1); - out_data[i] = FromInt4Converter::Convert(val); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - for (ptrdiff_t i = 0; i < shape_size; ++i) { - // elem 0 is the low nibble, 1 the high nibble - auto val = in_data[i >> 1].GetElem(i & 0x1); - - out_data[i] = val != 0; - } - } -}; - template <> struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { @@ -657,7 +607,7 @@ void CastMLFloat16ThroughFloatTensor( // tensor 
MLFloat16 -> X template struct TensorCaster && !std::is_same_v>> { + std::enable_if_t::value>> { void Cast(const OpKernelContext& context, const TensorShape& shape, const Tensor& in, Tensor& out) const { CastMLFloat16ThroughFloatTensor(context, shape, in, out); } @@ -692,10 +642,10 @@ struct TensorCasterNoSat { } }; -// TensorCasterNoSat should never be instantiated for Int4x2/UInt4x2 +// tensor (U)Int4 -> float 8 template struct TensorCasterNoSat || std::is_same_v>> { + std::enable_if_t::value>> { void Cast(const OpKernelContext& context, const TensorShape& shape, const Tensor& src, Tensor& dst) const { // Int4x2/UInt4x2 always fit inside any Float8 type, so we can reuse the saturate = true implementation. TensorCaster{}.Cast(context, shape, src, dst); From 343a22c325382078a57c11d1a6b10338181d0a33 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Sat, 19 Jul 2025 16:01:22 -0700 Subject: [PATCH 74/88] specialize ToInt4Converter for bool, merge TensorCaster from bool specializations --- .../core/providers/cpu/tensor/cast_op.cc | 66 ++++--------------- 1 file changed, 14 insertions(+), 52 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index fe4e19cc62c06..3e18f8bea7e1d 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -287,6 +287,17 @@ struct ToInt4Converter { } }; +template <> +struct ToInt4Converter { + static int8_t ConvertToInt4(const bool& val) { + return static_cast(val ? 1 : 0); + } + + static uint8_t ConvertToUInt4(const bool& val) { + return static_cast(val ? 1 : 0); + } +}; + // generic tensor X -> Y template struct TensorCaster { @@ -435,6 +446,7 @@ struct TensorCaster { } }; +// (U)Int4x2 -> integral/floating point types template struct TensorCaster::value && @@ -492,7 +504,7 @@ struct TensorCaster { template struct TensorCaster::value || std::is_floating_point_v || IsOrtFloat16Type::value || IsOrtFloat8Type::value>> { + std::enable_if_t || IsStandardIntegerType::value || std::is_floating_point_v || IsOrtFloat16Type::value || IsOrtFloat8Type::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); @@ -502,36 +514,11 @@ struct TensorCaster::ConvertToInt4(in_data[i]); int8_t high_val = ToInt4Converter::ConvertToInt4(in_data[i + 1]); - out_data[i >> 1] = Int4x2(low_val, high_val); } if (i < shape_size) { int8_t low_val = ToInt4Converter::ConvertToInt4(in_data[i]); - - out_data[i >> 1] = Int4x2(low_val, 0); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - ptrdiff_t i = 0; - for (; i < shape_size - 1; i += 2) { - int8_t low_val = in_data[i] ? 1 : 0; - int8_t high_val = in_data[i + 1] ? 1 : 0; - - out_data[i >> 1] = Int4x2(low_val, high_val); - } - - if (i < shape_size) { - int8_t low_val = in_data[i] ? 
1 : 0; - out_data[i >> 1] = Int4x2(low_val, 0); } } @@ -539,7 +526,7 @@ struct TensorCaster { template struct TensorCaster::value || std::is_floating_point_v || IsOrtFloat16Type::value || IsOrtFloat8Type::value>> { + std::enable_if_t || IsStandardIntegerType::value || std::is_floating_point_v || IsOrtFloat16Type::value || IsOrtFloat8Type::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); @@ -549,36 +536,11 @@ struct TensorCaster::ConvertToUInt4(in_data[i]); uint8_t high_val = ToInt4Converter::ConvertToUInt4(in_data[i + 1]); - out_data[i >> 1] = UInt4x2(low_val, high_val); } if (i < shape_size) { uint8_t low_val = ToInt4Converter::ConvertToUInt4(in_data[i]); - - out_data[i >> 1] = UInt4x2(low_val, 0); - } - } -}; - -template <> -struct TensorCaster { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - ptrdiff_t i = 0; - for (; i < shape_size - 1; i += 2) { - uint8_t low_val = in_data[i] ? 1 : 0; - uint8_t high_val = in_data[i + 1] ? 1 : 0; - - out_data[i >> 1] = UInt4x2(low_val, high_val); - } - - if (i < shape_size) { - uint8_t low_val = in_data[i] ? 1 : 0; - out_data[i >> 1] = UInt4x2(low_val, 0); } } From 195780ed742a4ea95206c168cb5f68f943b10025 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Sat, 19 Jul 2025 16:33:02 -0700 Subject: [PATCH 75/88] refactor ToInt4Converter, merge TensorCaster to Int4 specializations --- .../core/providers/cpu/tensor/cast_op.cc | 116 ++++++++++-------- 1 file changed, 65 insertions(+), 51 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 3e18f8bea7e1d..d0a71446953d2 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -79,6 +79,16 @@ struct IsStandardIntegerType { std::is_same_v; }; +template +struct IsOrtInt4NonStringConversionType { + static constexpr bool value = + std::is_same_v || + IsStandardIntegerType::value || + std::is_floating_point_v || + IsOrtFloat16Type::value || + IsOrtFloat8Type::value; +}; + // string cast helpers // Note: when C++17 is available, use functions @@ -208,63 +218,81 @@ struct FromInt4Converter { } }; -template +template ::value>> struct ToInt4Converter { + static typename DstType::UnpackedType Convert(const SrcType& val); +}; + +template +struct ToInt4Converter { // See https://onnx.ai/onnx/operators/onnx__Cast.html#summary // Casting from fixed point to fixed point: when OOR, discard higher bits and // reinterpret (with respect to two's complement representation for signed types). // For example, 200 (int16) -> -56 (int8). - static int8_t ConvertToInt4(const OtherType& val) { + static int8_t Convert(const SrcType& val) { // Truncate to 4 bits and sign-extend properly uint8_t truncated = static_cast(val) & 0x0F; // Sign-extend: if bit 3 is set, it's negative in 4-bit two's complement return static_cast((truncated & 0x8) ? 
(truncated | 0xF0) : truncated); } +}; - static uint8_t ConvertToUInt4(const OtherType& val) { +template +struct ToInt4Converter { + static uint8_t Convert(const SrcType& val) { // Truncate to 4 bits return static_cast(val) & 0x0F; } }; template <> -struct ToInt4Converter { - static int8_t ConvertToInt4(const float& val) { +struct ToInt4Converter { + static int8_t Convert(const float& val) { int result = static_cast(std::roundf(val)); uint8_t truncated = static_cast(result) & 0x0F; return static_cast((truncated & 0x8) ? (truncated | 0xF0) : truncated); } +}; - static uint8_t ConvertToUInt4(const float& val) { +template <> +struct ToInt4Converter { + static uint8_t Convert(const float& val) { int result = static_cast(std::roundf(val)); return static_cast(result) & 0x0F; } }; template <> -struct ToInt4Converter { - static int8_t ConvertToInt4(const double& val) { +struct ToInt4Converter { + static int8_t Convert(const double& val) { int result = static_cast(std::round(val)); uint8_t truncated = static_cast(result) & 0x0F; return static_cast((truncated & 0x8) ? (truncated | 0xF0) : truncated); } +}; - static uint8_t ConvertToUInt4(const double& val) { +template <> +struct ToInt4Converter { + static uint8_t Convert(const double& val) { int result = static_cast(std::round(val)); return static_cast(result) & 0x0F; } }; template <> -struct ToInt4Converter { - static int8_t ConvertToInt4(const MLFloat16& val) { +struct ToInt4Converter { + static int8_t Convert(const MLFloat16& val) { float f_val = static_cast(val); int result = static_cast(std::roundf(f_val)); uint8_t truncated = static_cast(result) & 0x0F; return static_cast((truncated & 0x8) ? (truncated | 0xF0) : truncated); } +}; - static uint8_t ConvertToUInt4(const MLFloat16& val) { +template <> +struct ToInt4Converter { + static uint8_t Convert(const MLFloat16& val) { float f_val = static_cast(val); int result = static_cast(std::roundf(f_val)); return static_cast(result) & 0x0F; @@ -272,15 +300,18 @@ struct ToInt4Converter { }; template <> -struct ToInt4Converter { - static int8_t ConvertToInt4(const BFloat16& val) { +struct ToInt4Converter { + static int8_t Convert(const BFloat16& val) { float f_val = static_cast(val); int result = static_cast(std::roundf(f_val)); uint8_t truncated = static_cast(result) & 0x0F; return static_cast((truncated & 0x8) ? (truncated | 0xF0) : truncated); } +}; - static uint8_t ConvertToUInt4(const BFloat16& val) { +template <> +struct ToInt4Converter { + static uint8_t Convert(const BFloat16& val) { float f_val = static_cast(val); int result = static_cast(std::roundf(f_val)); return static_cast(result) & 0x0F; @@ -288,12 +319,15 @@ struct ToInt4Converter { }; template <> -struct ToInt4Converter { - static int8_t ConvertToInt4(const bool& val) { +struct ToInt4Converter { + static int8_t Convert(const bool& val) { return static_cast(val ? 1 : 0); } +}; - static uint8_t ConvertToUInt4(const bool& val) { +template <> +struct ToInt4Converter { + static uint8_t Convert(const bool& val) { return static_cast(val ? 
1 : 0); } }; @@ -449,8 +483,7 @@ struct TensorCaster { // (U)Int4x2 -> integral/floating point types template struct TensorCaster::value && - (std::is_same_v || IsStandardIntegerType::value || std::is_floating_point_v || IsOrtFloat16Type::value || IsOrtFloat8Type::value)>> { + std::enable_if_t::value && IsOrtInt4NonStringConversionType::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); @@ -464,6 +497,7 @@ struct TensorCaster UInt4x2 template <> struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { @@ -483,6 +517,7 @@ struct TensorCaster { } }; +// UInt4x2 -> Int4x2 template <> struct TensorCaster { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { @@ -502,46 +537,25 @@ struct TensorCaster { } }; -template -struct TensorCaster || IsStandardIntegerType::value || std::is_floating_point_v || IsOrtFloat16Type::value || IsOrtFloat8Type::value>> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - ptrdiff_t i = 0; - for (; i < shape_size - 1; i += 2) { - int8_t low_val = ToInt4Converter::ConvertToInt4(in_data[i]); - int8_t high_val = ToInt4Converter::ConvertToInt4(in_data[i + 1]); - out_data[i >> 1] = Int4x2(low_val, high_val); - } - - if (i < shape_size) { - int8_t low_val = ToInt4Converter::ConvertToInt4(in_data[i]); - out_data[i >> 1] = Int4x2(low_val, 0); - } - } -}; - -template -struct TensorCaster || IsStandardIntegerType::value || std::is_floating_point_v || IsOrtFloat16Type::value || IsOrtFloat8Type::value>> { +// integral/floating point types -> (U)Int4x2 +template +struct TensorCaster::value && IsOrtInt4Type::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); + auto* out_data = out.MutableData(); ptrdiff_t i = 0; for (; i < shape_size - 1; i += 2) { - uint8_t low_val = ToInt4Converter::ConvertToUInt4(in_data[i]); - uint8_t high_val = ToInt4Converter::ConvertToUInt4(in_data[i + 1]); - out_data[i >> 1] = UInt4x2(low_val, high_val); + auto low_val = ToInt4Converter::Convert(in_data[i]); + auto high_val = ToInt4Converter::Convert(in_data[i + 1]); + out_data[i >> 1] = DstType(low_val, high_val); } if (i < shape_size) { - uint8_t low_val = ToInt4Converter::ConvertToUInt4(in_data[i]); - out_data[i >> 1] = UInt4x2(low_val, 0); + auto low_val = ToInt4Converter::Convert(in_data[i]); + out_data[i >> 1] = DstType(low_val, 0); } } }; From 506a48b7ed0995a6353b430f16fa9bf6d97aff3c Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Sat, 19 Jul 2025 16:51:29 -0700 Subject: [PATCH 76/88] enforce float 8 DstType for TensorCasterNoSat --- .../core/providers/cpu/tensor/cast_op.cc | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index d0a71446953d2..6766c134611c0 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -218,6 +218,7 @@ struct FromInt4Converter { } }; +// Helper for converting any source type 
to (U)Int4x2::UnpackedType values (int8_t and uint8_t). template ::value>> struct ToInt4Converter { @@ -560,6 +561,7 @@ struct TensorCaster { }; #endif -#if !defined(DISABLE_FLOAT8_TYPES) -// TensorCasterNoSat is only called when all the below conditions are met (see Cast::Compute implementation): +#if !defined(DISABLE_FLOAT8_TYPES) +// TensorCasterNoSat is only called when all the below conditions are met (see Cast::Compute): // - defined(DISABLE_FLOAT8_TYPES) == false // - saturate_ == false // - IsOrtFloat8Type::value == true // tensor X -> float 8 -template +template ::value>> struct TensorCasterNoSat { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const std::ptrdiff_t shape_size = narrow(shape.Size()); @@ -618,10 +621,10 @@ struct TensorCasterNoSat { } }; -// tensor (U)Int4 -> float 8 +// tensor (U)Int4x2 -> float 8 template struct TensorCasterNoSat::value>> { + std::enable_if_t::value && IsOrtFloat8Type::value>> { void Cast(const OpKernelContext& context, const TensorShape& shape, const Tensor& src, Tensor& dst) const { // Int4x2/UInt4x2 always fit inside any Float8 type, so we can reuse the saturate = true implementation. TensorCaster{}.Cast(context, shape, src, dst); @@ -630,7 +633,8 @@ struct TensorCasterNoSat float 8 template -struct TensorCasterNoSat { +struct TensorCasterNoSat::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const std::ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); @@ -645,6 +649,7 @@ struct TensorCasterNoSat { #endif + class Cast final : public OpKernel { public: Cast(const OpKernelInfo& info) : OpKernel(info) { From bc9edef1268b0e9d1bbf087587f51375f0d45359 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Sat, 19 Jul 2025 17:23:14 -0700 Subject: [PATCH 77/88] refactor ToInt4Converter, merge specializations --- .../core/providers/cpu/tensor/cast_op.cc | 65 ++++++------------- 1 file changed, 19 insertions(+), 46 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 6766c134611c0..5d7c28ab75d5d 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -226,7 +226,8 @@ struct ToInt4Converter { }; template -struct ToInt4Converter { +struct ToInt4Converter::value>> { // See https://onnx.ai/onnx/operators/onnx__Cast.html#summary // Casting from fixed point to fixed point: when OOR, discard higher bits and // reinterpret (with respect to two's complement representation for signed types). 
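For floating-point sources the converters round to the nearest integer first and then apply the same truncation, mirroring the float specialization nearby; a small sketch under that assumption (FloatToInt4 is an illustrative name, not part of the patch) shows how the values used in the FloatToInt4x2 test map.

#include <cassert>
#include <cmath>
#include <cstdint>

int8_t FloatToInt4(float value) {
  // Round to the nearest integer first, then truncate to the low nibble and sign-extend.
  int rounded = static_cast<int>(std::roundf(value));
  uint8_t truncated = static_cast<uint8_t>(rounded) & 0x0F;
  return static_cast<int8_t>((truncated & 0x8) ? (truncated | 0xF0) : truncated);
}

int main() {
  assert(FloatToInt4(-10.7f) == 5);   // rounds to -11 (0xF5), low nibble 0x5
  assert(FloatToInt4(15.3f) == -1);   // rounds to 15 (0x0F), sign-extends to -1
  assert(FloatToInt4(-1.6f) == -2);   // rounds to -2, already in range
  return 0;
}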
@@ -240,7 +241,8 @@ struct ToInt4Converter { }; template -struct ToInt4Converter { +struct ToInt4Converter::value>> { static uint8_t Convert(const SrcType& val) { // Truncate to 4 bits return static_cast(val) & 0x0F; @@ -256,14 +258,6 @@ struct ToInt4Converter { } }; -template <> -struct ToInt4Converter { - static uint8_t Convert(const float& val) { - int result = static_cast(std::roundf(val)); - return static_cast(result) & 0x0F; - } -}; - template <> struct ToInt4Converter { static int8_t Convert(const double& val) { @@ -273,56 +267,35 @@ struct ToInt4Converter { } }; -template <> -struct ToInt4Converter { - static uint8_t Convert(const double& val) { - int result = static_cast(std::round(val)); - return static_cast(result) & 0x0F; - } -}; - -template <> -struct ToInt4Converter { - static int8_t Convert(const MLFloat16& val) { - float f_val = static_cast(val); - int result = static_cast(std::roundf(f_val)); - uint8_t truncated = static_cast(result) & 0x0F; - return static_cast((truncated & 0x8) ? (truncated | 0xF0) : truncated); - } -}; - -template <> -struct ToInt4Converter { - static uint8_t Convert(const MLFloat16& val) { +template +struct ToInt4Converter::value && IsOrtInt4Type::value>> { + static typename DstType::UnpackedType Convert(const SrcType& val) { float f_val = static_cast(val); - int result = static_cast(std::roundf(f_val)); - return static_cast(result) & 0x0F; + return ToInt4Converter::Convert(f_val); } }; template <> -struct ToInt4Converter { - static int8_t Convert(const BFloat16& val) { - float f_val = static_cast(val); - int result = static_cast(std::roundf(f_val)); - uint8_t truncated = static_cast(result) & 0x0F; - return static_cast((truncated & 0x8) ? (truncated | 0xF0) : truncated); +struct ToInt4Converter { + static int8_t Convert(const bool& val) { + return static_cast(val ? 1 : 0); } }; template <> -struct ToInt4Converter { - static uint8_t Convert(const BFloat16& val) { - float f_val = static_cast(val); - int result = static_cast(std::roundf(f_val)); +struct ToInt4Converter { + static uint8_t Convert(const float& val) { + int result = static_cast(std::roundf(val)); return static_cast(result) & 0x0F; } }; template <> -struct ToInt4Converter { - static int8_t Convert(const bool& val) { - return static_cast(val ? 1 : 0); +struct ToInt4Converter { + static uint8_t Convert(const double& val) { + int result = static_cast(std::round(val)); + return static_cast(result) & 0x0F; } }; From 4d652f0453c5b3032eb0075417346741d00f7331 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Sat, 19 Jul 2025 18:20:32 -0700 Subject: [PATCH 78/88] Refactor ToInt4Converter, merge specializations --- .../core/providers/cpu/tensor/cast_op.cc | 90 ++++++++----------- 1 file changed, 39 insertions(+), 51 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 5d7c28ab75d5d..eca181e72f820 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -201,7 +201,7 @@ struct EigenCastType { // Helper for converting (U)Int4x2 values to any destination type. template ::value>> + typename Enable = std::enable_if_t::value && IsOrtInt4NonStringConversionType::value>> struct FromInt4Converter { // The input 'val' can be either an int8_t value coming from Int4x2.GetElem(pos), // or an uint8_t value coming from UInt4x2.GetElem(pos), where pos can be 0 or 1. 
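As context for how these converters are fed, a self-contained sketch of the indexing pattern the casters use: flat element i lives in packed element i >> 1, and i & 0x1 selects the low or high nibble via GetElem. PackedInt4 is a hypothetical stand-in for Int4x2/UInt4x2, used only so the snippet compiles on its own; per the comments in the casters, the real GetElem returns the low nibble for position 0 and the high nibble for position 1.

#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for the packed 4-bit types.
struct PackedInt4 {
  int8_t low;
  int8_t high;
  int8_t GetElem(size_t pos) const { return pos == 0 ? low : high; }
};

void UnpackToFloat(const PackedInt4* in, float* out, size_t element_count) {
  for (size_t i = 0; i < element_count; ++i) {
    // i >> 1 picks the packed pair, i & 0x1 picks the nibble within it.
    out[i] = static_cast<float>(in[i >> 1].GetElem(i & 0x1));
  }
}

int main() {
  const PackedInt4 packed[2] = {{1, -2}, {3, 4}};
  float unpacked[4] = {};
  UnpackToFloat(packed, unpacked, 4);
  // unpacked now holds {1.0f, -2.0f, 3.0f, 4.0f}
  return 0;
}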
@@ -220,18 +220,19 @@ struct FromInt4Converter { // Helper for converting any source type to (U)Int4x2::UnpackedType values (int8_t and uint8_t). template ::value>> + typename Enable = std::enable_if_t::value && IsOrtInt4Type::value>> struct ToInt4Converter { static typename DstType::UnpackedType Convert(const SrcType& val); }; template struct ToInt4Converter::value>> { - // See https://onnx.ai/onnx/operators/onnx__Cast.html#summary - // Casting from fixed point to fixed point: when OOR, discard higher bits and - // reinterpret (with respect to two's complement representation for signed types). - // For example, 200 (int16) -> -56 (int8). + std::enable_if_t::value>> { + // See https://onnx.ai/onnx/operators/onnx__Cast.html#summary for casting from + // fixed point to fixed point: when OOR, discard higher bits and reinterpret + // (with respect to two's complement representation for signed types). + // The following example is listed: 200 (int16) converts to -56 (int8). + // For our int4 conversion, 200 (int16) would convert to -8 (int4). static int8_t Convert(const SrcType& val) { // Truncate to 4 bits and sign-extend properly uint8_t truncated = static_cast(val) & 0x0F; @@ -242,67 +243,57 @@ struct ToInt4Converter struct ToInt4Converter::value>> { + std::enable_if_t::value>> { + // See https://onnx.ai/onnx/operators/onnx__Cast.html#summary for casting from + // fixed point to fixed point: when OOR, discard higher bits and reinterpret + // (with respect to two's complement representation for signed types). static uint8_t Convert(const SrcType& val) { // Truncate to 4 bits return static_cast(val) & 0x0F; } }; -template <> -struct ToInt4Converter { - static int8_t Convert(const float& val) { +template +struct ToInt4Converter::value>> { + static typename DstType::UnpackedType Convert(const bool& val) { + return static_cast(val ? 1 : 0); + } +}; + +template +struct ToInt4Converter::value>> { + static typename DstType::UnpackedType Convert(const float& val) { int result = static_cast(std::roundf(val)); - uint8_t truncated = static_cast(result) & 0x0F; - return static_cast((truncated & 0x8) ? (truncated | 0xF0) : truncated); + return ToInt4Converter::Convert(result); } }; -template <> -struct ToInt4Converter { - static int8_t Convert(const double& val) { +template +struct ToInt4Converter::value>> { + static typename DstType::UnpackedType Convert(const double& val) { int result = static_cast(std::round(val)); - uint8_t truncated = static_cast(result) & 0x0F; - return static_cast((truncated & 0x8) ? (truncated | 0xF0) : truncated); + return ToInt4Converter::Convert(result); } }; template struct ToInt4Converter::value && IsOrtInt4Type::value>> { + std::enable_if_t::value && IsOrtInt4Type::value>> { static typename DstType::UnpackedType Convert(const SrcType& val) { - float f_val = static_cast(val); - return ToInt4Converter::Convert(f_val); + float result = val.ToFloat(); + return ToInt4Converter::Convert(result); } }; -template <> -struct ToInt4Converter { - static int8_t Convert(const bool& val) { - return static_cast(val ? 1 : 0); - } -}; - -template <> -struct ToInt4Converter { - static uint8_t Convert(const float& val) { - int result = static_cast(std::roundf(val)); - return static_cast(result) & 0x0F; - } -}; - -template <> -struct ToInt4Converter { - static uint8_t Convert(const double& val) { - int result = static_cast(std::round(val)); - return static_cast(result) & 0x0F; - } -}; - -template <> -struct ToInt4Converter { - static uint8_t Convert(const bool& val) { - return static_cast(val ? 
1 : 0); +template +struct ToInt4Converter::value && IsOrtInt4Type::value>> { + static typename DstType::UnpackedType Convert(const SrcType& val) { + float f_val = static_cast(val); + return ToInt4Converter::Convert(f_val); } }; @@ -534,7 +525,6 @@ struct TensorCaster { }; #endif - #if !defined(DISABLE_FLOAT8_TYPES) // TensorCasterNoSat is only called when all the below conditions are met (see Cast::Compute): // - defined(DISABLE_FLOAT8_TYPES) == false @@ -622,7 +611,6 @@ struct TensorCasterNoSat Date: Sat, 19 Jul 2025 18:39:10 -0700 Subject: [PATCH 79/88] rename type, update comments --- .../core/providers/cpu/tensor/cast_op.cc | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index eca181e72f820..858e0ef392281 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -79,8 +79,9 @@ struct IsStandardIntegerType { std::is_same_v; }; +// Types that Int4x2 and UInt4x2 convert to and from, apart from string. template -struct IsOrtInt4NonStringConversionType { +struct IsOrtInt4NumericConversionType { static constexpr bool value = std::is_same_v || IsStandardIntegerType::value || @@ -201,7 +202,7 @@ struct EigenCastType { // Helper for converting (U)Int4x2 values to any destination type. template ::value && IsOrtInt4NonStringConversionType::value>> + typename Enable = std::enable_if_t::value && IsOrtInt4NumericConversionType::value>> struct FromInt4Converter { // The input 'val' can be either an int8_t value coming from Int4x2.GetElem(pos), // or an uint8_t value coming from UInt4x2.GetElem(pos), where pos can be 0 or 1. @@ -220,19 +221,19 @@ struct FromInt4Converter { // Helper for converting any source type to (U)Int4x2::UnpackedType values (int8_t and uint8_t). template ::value && IsOrtInt4Type::value>> + typename Enable = std::enable_if_t::value && IsOrtInt4Type::value>> struct ToInt4Converter { static typename DstType::UnpackedType Convert(const SrcType& val); }; +// See https://onnx.ai/onnx/operators/onnx__Cast.html#summary for casting from +// fixed point to fixed point: when OOR, discard higher bits and reinterpret +// (with respect to two's complement representation for signed types). +// The following example is listed: 200 (int16) converts to -56 (int8). +// For our int4 conversion, 200 (int16) would convert to -8 (int4). template struct ToInt4Converter::value>> { - // See https://onnx.ai/onnx/operators/onnx__Cast.html#summary for casting from - // fixed point to fixed point: when OOR, discard higher bits and reinterpret - // (with respect to two's complement representation for signed types). - // The following example is listed: 200 (int16) converts to -56 (int8). - // For our int4 conversion, 200 (int16) would convert to -8 (int4). static int8_t Convert(const SrcType& val) { // Truncate to 4 bits and sign-extend properly uint8_t truncated = static_cast(val) & 0x0F; @@ -241,12 +242,12 @@ struct ToInt4Converter struct ToInt4Converter::value>> { - // See https://onnx.ai/onnx/operators/onnx__Cast.html#summary for casting from - // fixed point to fixed point: when OOR, discard higher bits and reinterpret - // (with respect to two's complement representation for signed types). 
static uint8_t Convert(const SrcType& val) { // Truncate to 4 bits return static_cast(val) & 0x0F; @@ -261,6 +262,8 @@ struct ToInt4Converter struct ToInt4Converter::value>> { @@ -448,7 +451,7 @@ struct TensorCaster { // (U)Int4x2 -> integral/floating point types template struct TensorCaster::value && IsOrtInt4NonStringConversionType::value>> { + std::enable_if_t::value && IsOrtInt4NumericConversionType::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); @@ -505,7 +508,7 @@ struct TensorCaster { // integral/floating point types -> (U)Int4x2 template struct TensorCaster::value && IsOrtInt4Type::value>> { + std::enable_if_t::value && IsOrtInt4Type::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); From fa705d63c4faf90b48fc1c5be09a2e58cff97561 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Sat, 19 Jul 2025 18:57:42 -0700 Subject: [PATCH 80/88] small refactor --- .../core/providers/cpu/tensor/cast_op.cc | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 858e0ef392281..cf6843a7b97a2 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -448,7 +448,7 @@ struct TensorCaster { } }; -// (U)Int4x2 -> integral/floating point types +// (U)Int4x2 -> numeric types template struct TensorCaster::value && IsOrtInt4NumericConversionType::value>> { @@ -465,6 +465,29 @@ struct TensorCaster (U)Int4x2 +template +struct TensorCaster::value && IsOrtInt4Type::value>> { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + ptrdiff_t i = 0; + for (; i < shape_size - 1; i += 2) { + auto low_val = ToInt4Converter::Convert(in_data[i]); + auto high_val = ToInt4Converter::Convert(in_data[i + 1]); + out_data[i >> 1] = DstType(low_val, high_val); + } + + if (i < shape_size) { + auto low_val = ToInt4Converter::Convert(in_data[i]); + out_data[i >> 1] = DstType(low_val, 0); + } + } +}; + // Int4x2 -> UInt4x2 template <> struct TensorCaster { @@ -505,29 +528,6 @@ struct TensorCaster { } }; -// integral/floating point types -> (U)Int4x2 -template -struct TensorCaster::value && IsOrtInt4Type::value>> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - ptrdiff_t i = 0; - for (; i < shape_size - 1; i += 2) { - auto low_val = ToInt4Converter::Convert(in_data[i]); - auto high_val = ToInt4Converter::Convert(in_data[i + 1]); - out_data[i >> 1] = DstType(low_val, high_val); - } - - if (i < shape_size) { - auto low_val = ToInt4Converter::Convert(in_data[i]); - out_data[i >> 1] = DstType(low_val, 0); - } - } -}; - #if defined(_M_AMD64) && !defined(_M_ARM64EC) // specializations to use optimized and Windows x64-specific From 4a65285de4dc86f692f7a8bd9413a9af5fd780f6 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Sat, 19 Jul 2025 21:00:16 -0700 Subject: [PATCH 81/88] update comments, test values --- 
.../core/providers/cpu/tensor/cast_op.cc | 14 +++++++-- .../test/providers/cpu/tensor/cast_op_test.cc | 31 +++++++++++-------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index cf6843a7b97a2..b1e8077fb559a 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -235,9 +235,17 @@ template struct ToInt4Converter::value>> { static int8_t Convert(const SrcType& val) { - // Truncate to 4 bits and sign-extend properly + // Example: int8_t(14) converts to int4 (-2) + // int8_t(14) is 0000_1110 + // truncate: 0000_1110 & 0000_1111 = 0000_1110 + // in 4-bit two's complement, 1110 = 1 * -8 + 1 * 4 + 1 * 2 + 1 * 0 = -2 + // sign-extend: -2 in int8_t is 1111_0000 | 0000_1110 = 1111_1110 + + // Truncate to 4 least significant bits uint8_t truncated = static_cast(val) & 0x0F; - // Sign-extend: if bit 3 is set, it's negative in 4-bit two's complement + + // Sign-extend: if bit 3 is set, it's negative in 4-bit two's complement, + // so set the 4 most significant bits to 1. return static_cast((truncated & 0x8) ? (truncated | 0xF0) : truncated); } }; @@ -249,7 +257,7 @@ template struct ToInt4Converter::value>> { static uint8_t Convert(const SrcType& val) { - // Truncate to 4 bits + // Truncate to 4 least significant bits return static_cast(val) & 0x0F; } }; diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index c24913d71d72f..8733484308381 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -233,8 +233,8 @@ TEST(CastOpTest, Int4x2ToUInt8) { Int4x2(6, 2) // both positive }; - // -8 becomes 248, -1 becomes 255, etc. - const std::vector expected_uint8_output = {248, 7, 0, 255, 3, 251, 6, 2}; + // Negative values will be cast to their unsigned representation + const std::vector expected_uint8_output = {248, 7, 0, UINT8_MAX, 3, 251, 6, 2}; // WHEN, THEN TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint8_output), shape); @@ -267,7 +267,7 @@ TEST(CastOpTest, Int4x2ToUInt16) { }; // Negative values will be cast to their unsigned representation - const std::vector expected_uint16_output = {65528, 7, 0, 65535, 3, 65531, 6, 2}; + const std::vector expected_uint16_output = {65528, 7, 0, UINT16_MAX, 3, 65531, 6, 2}; // WHEN, THEN TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_uint16_output), shape); @@ -698,7 +698,7 @@ TEST(CastOpTest, Int4x2ToUInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; const std::vector int4x2_input = { - Int4x2(-8, 7), // negative values + Int4x2(-8, 7), // min and max values Int4x2(0, -1), // -1 becomes max unsigned value Int4x2(3, -5), // positive and negative values Int4x2(6, 2) // positive values @@ -739,12 +739,17 @@ TEST(CastOpTest, UInt4x2ToInt4x2) { TEST(CastOpTest, Int8ToInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector int8_input = {-10, 15, 0, -1, 3, -5, -128, 127}; + const std::vector int8_input = {-10, 15, 0, -1, 7, -8, -128, 127}; const std::vector expected_int4x2_output = { + // 10 in binary is 00001010. + // Invert all bits -> 11110101, add 1 -> 11110110 + // So -10 in binary is 11110110. + // Truncate to 4 least significant bits -> 0110. + // In 4-bit two's complement, 0110 = 0 * -8 + 1 * 4 + 1 * 2 = 6. 
Int4x2(6, -1), // -10 truncated to 6, 15 truncated to -1 Int4x2(0, -1), // 0 unchanged, -1 unchanged - Int4x2(3, -5), // 3 unchanged, -5 unchanged + Int4x2(7, -8), // 7 unchanged, -8 unchanged Int4x2(0, -1) // -128 truncated to 0, 127 truncated to -1 }; @@ -772,13 +777,13 @@ TEST(CastOpTest, UInt8ToUInt4x2) { TEST(CastOpTest, Int16ToInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector int16_input = {-10, 32767, 0, -32768, 3, -5, 240, 31}; + const std::vector int16_input = {-10, 32767, 0, -32768, 7, -8, 240, 31}; // values get truncated to lower 4 bits and sign-extended const std::vector expected_int4x2_output = { Int4x2(6, -1), // -10 (0xFFF6) truncated to 6, 32767 (0x7FFF) truncated to -1 Int4x2(0, 0), // 0 (0x0000) truncated to 0, -32768 (0x8000) truncated to 0 - Int4x2(3, -5), // 3 (0x0003) truncated to 3, -5 (0xFFFB) truncated to -5 + Int4x2(7, -8), // 7 (0x0007) truncated to 7, -8 (0xFFF8) truncated to -8 Int4x2(0, -1) // 240 (0x00F0) truncated to 0, 31 (0x001F) truncated to -1 }; @@ -865,13 +870,13 @@ TEST(CastOpTest, UInt32ToUInt4x2) { TEST(CastOpTest, Int64ToInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector int64_input = {-10, INT64_MAX, 0, INT64_MIN, 3, -5, 65520, 4111}; + const std::vector int64_input = {-10, INT64_MAX, 0, INT64_MIN, 7, -8, 65520, 4111}; // values get truncated to lower 4 bits and sign-extended const std::vector expected_int4x2_output = { Int4x2(6, -1), // -10 truncated to 6, 9223372036854775807 truncated to -1 Int4x2(0, 0), // 0 truncated to 0, -9223372036854775808 truncated to 0 - Int4x2(3, -5), // 3 truncated to 3, -5 truncated to -5 + Int4x2(7, -8), // 7 truncated to 7, -8 truncated to -8 Int4x2(0, -1) // 65520 truncated to 0, 4111 truncated to -1 }; @@ -899,12 +904,12 @@ TEST(CastOpTest, UInt64ToUInt4x2) { TEST(CastOpTest, FloatToInt4x2) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector float_input = {-10.7f, 15.3f, 0.4f, -1.6f, 3.8f, -5.2f, 240.1f, 31.9f}; + const std::vector float_input = {-10.7f, 15.3f, 0.4f, -1.6f, 7.0f, -8.0f, 240.1f, 31.9f}; const std::vector expected_int4x2_output = { Int4x2(5, -1), // -10.7 rounded to -11 (0xF5), truncated to 5, sign-extended to 5; 15.3 rounded to 15 (0x0F), sign-extended to -1 Int4x2(0, -2), // 0.4 rounded to 0; -1.6 rounded to -2 (0xFE), truncated to 14 (0x0E), sign-extended to -2 - Int4x2(4, -5), // 3.8 rounded to 4; -5.2 rounded to -5 (0xFB), truncated to 11 (0x0B), sign-extended to -5 + Int4x2(7, -8), // 7.0 converted to 7; -8.0 converted to -8 Int4x2(0, 0) // 240.1 rounded to 240 (0xF0), truncated to 0; 31.9 rounded to 32 (0x20), truncated to 0 }; @@ -1191,7 +1196,7 @@ TEST(CastOpTest, StringToUInt4x2) { TestCastOp(gsl::span(string_input), gsl::span(expected_output), shape); } -TEST(CastOpTest, String2UInt4x2BoundaryValues) { +TEST(CastOpTest, StringToUInt4x2BoundaryValues) { // GIVEN // Test string values that need truncation to UInt4x2 range const std::vector shape{3, 2}; From 226c6dc3d345c4826f2c36e31c6779e2e7525ea0 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Sat, 19 Jul 2025 21:32:49 -0700 Subject: [PATCH 82/88] Add 2 tests --- .../test/providers/cpu/tensor/cast_op_test.cc | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 8733484308381..7542ce4942719 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -1376,6 +1376,48 @@ 
From 226c6dc3d345c4826f2c36e31c6779e2e7525ea0 Mon Sep 17 00:00:00 2001
From: Alex Marin
Date: Sat, 19 Jul 2025 21:32:49 -0700
Subject: [PATCH 82/88] Add 2 tests

---
 .../test/providers/cpu/tensor/cast_op_test.cc | 42 +++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
index 8733484308381..7542ce4942719 100644
--- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
@@ -1376,6 +1376,48 @@ TEST(CastOpTest, UInt4x2ToFloat8E5M2) {
                 OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False);
 }
 
+TEST(CastOpTest, Float8E4M3FNToInt4x2) {
+  // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  std::vector<Float8E4M3FN> float8_input;
+  const std::vector<float> input_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f};
+  for (float val : input_values) {
+    float8_input.emplace_back(Float8E4M3FN(val, true));
+  }
+
+  const std::vector<Int4x2> expected_int4x2_output = {
+      Int4x2(-8, 7),
+      Int4x2(0, -1),
+      Int4x2(3, -5),
+      Int4x2(6, 2)};
+
+  // WHEN, THEN
+  // The 'saturate_' bool inside the 'Cast' class can only be false if the conversion is to a float 8 type,
+  // so it's sufficient to test with the default saturate = 1 here, since we are not converting to float 8.
+  TestCastOp<Float8E4M3FN, Int4x2>(gsl::make_span(float8_input), gsl::make_span(expected_int4x2_output), shape);
+}
+
+TEST(CastOpTest, Float8E4M3FNToUInt4x2) {
+  // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  std::vector<Float8E4M3FN> uint_float8_input;
+  const std::vector<float> uint_input_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f};
+  for (float val : uint_input_values) {
+    uint_float8_input.emplace_back(Float8E4M3FN(val, true));
+  }
+
+  const std::vector<UInt4x2> expected_uint4x2_output = {
+      UInt4x2(0, 15),
+      UInt4x2(1, 14),
+      UInt4x2(7, 8),
+      UInt4x2(3, 12)};
+
+  // WHEN, THEN
+  // The 'saturate_' bool inside the 'Cast' class can only be false if the conversion is to a float 8 type,
+  // so it's sufficient to test with the default saturate = 1 here, since we are not converting to float 8.
+  TestCastOp<Float8E4M3FN, UInt4x2>(gsl::make_span(uint_float8_input), gsl::make_span(expected_uint4x2_output), shape);
+}
+
 #endif
 
 }  // namespace test

From 6d97d37a1c42dedc28561a0b9d66bc1dbe7d4f56 Mon Sep 17 00:00:00 2001
From: Alex Marin
Date: Sat, 19 Jul 2025 22:02:53 -0700
Subject: [PATCH 83/88] merge int4 -> string specialization with int4 -> numeric

---
 .../core/providers/cpu/tensor/cast_op.cc | 37 ++++++-------------
 1 file changed, 12 insertions(+), 25 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
index b1e8077fb559a..10e20770157c9 100644
--- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc
+++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
@@ -90,6 +90,11 @@ struct IsOrtInt4NumericConversionType {
                                       IsOrtFloat8Type<T>::value;
 };
 
+template <typename T>
+struct IsOrtInt4ConversionType {
+  static constexpr bool value = IsOrtInt4NumericConversionType<T>::value || std::is_same_v<T, std::string>;
+};
+
 // string cast helpers
 
 // Note: when C++17 is available, use functions
@@ -202,7 +207,7 @@ struct EigenCastType {
 
 // Helper for converting (U)Int4x2 values to any destination type.
 template <typename SrcType, typename DstType,
-          typename Enable = std::enable_if_t<IsOrtInt4Type<SrcType>::value && IsOrtInt4NumericConversionType<DstType>::value>>
+          typename Enable = std::enable_if_t<IsOrtInt4Type<SrcType>::value && IsOrtInt4ConversionType<DstType>::value>>
 struct FromInt4Converter {
   // The input 'val' can be either an int8_t value coming from Int4x2.GetElem(pos),
   // or an uint8_t value coming from UInt4x2.GetElem(pos), where pos can be 0 or 1.
@@ -213,6 +218,10 @@ struct FromInt4Converter {
       return DstType(static_cast<float>(val), true);
     } else if constexpr (std::is_same_v<DstType, bool>) {
       return val != 0;
+    } else if constexpr (std::is_same_v<DstType, std::string>) {
+      // val has type (u)int8_t, so static_cast is required in order for std::to_string
+      // to interpret val as a number (65 -> "65"), instead of a char (65 -> "A").
+ return std::to_string(static_cast(val)); } else { return static_cast(val); } @@ -351,28 +360,6 @@ struct TensorCaster { } }; -// tensor (U)Int4x2 -> string -template -struct TensorCaster::value>> { - void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { - const std::ptrdiff_t shape_size = narrow(shape.Size()); - const auto* in_data = in.Data(); - auto* out_data = out.MutableData(); - - // Unpack each Int4x2/UInt4x2 into two separate string elements - size_t out_idx = 0; - for (ptrdiff_t i = 0; i < shape_size; ++i) { - // elem 0 is the low nibble, 1 the high nibble - auto val = in_data[i >> 1].GetElem(i & 0x1); - - // val has type (u)int8_t, so static_cast is required in order for std::to_string - // to interpret val as a number (65 -> "65"), instead of a char (65 -> "A"). - out_data[out_idx++] = std::to_string(static_cast(val)); - } - } -}; - // tensor string -> Int4x2 template <> struct TensorCaster { @@ -456,10 +443,10 @@ struct TensorCaster { } }; -// (U)Int4x2 -> numeric types +// (U)Int4x2 -> string or numeric types template struct TensorCaster::value && IsOrtInt4NumericConversionType::value>> { + std::enable_if_t::value && IsOrtInt4ConversionType::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); From be6c17672aa65e2dbf9f9b0bef44c73c75ed6660 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Sat, 19 Jul 2025 22:14:34 -0700 Subject: [PATCH 84/88] reuse ToInt4Converter inside string specializations --- onnxruntime/core/providers/cpu/tensor/cast_op.cc | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 10e20770157c9..d16db3142adb8 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -368,23 +368,19 @@ struct TensorCaster { const auto* in_data = in.Data(); auto* out_data = out.MutableData(); - auto truncateToLower4BitsAndSignExtend = [](auto val) { - return (val & 0xF) | (-(val & 0x8) & 0xF0); - }; - // Every 2 strings combine into 1 Int4x2 const ptrdiff_t out_size = (shape_size + 1) >> 1; for (ptrdiff_t i = 0; i < out_size; ++i) { const ptrdiff_t in_idx = i << 1; int v0 = std::stoi(in_data[in_idx]); - int8_t val0 = static_cast(truncateToLower4BitsAndSignExtend(v0)); + int8_t val0 = ToInt4Converter::Convert(v0); // Parse second value (or use 0 if odd number of elements) int8_t val1 = 0; if (in_idx + 1 < shape_size) { int v1 = std::stoi(in_data[in_idx + 1]); - val1 = static_cast(truncateToLower4BitsAndSignExtend(v1)); + val1 = ToInt4Converter::Convert(v1); } out_data[i] = Int4x2(val0, val1); @@ -407,13 +403,13 @@ struct TensorCaster { // Parse first value and truncate to lower 4 bits int v0 = std::stoi(in_data[in_idx]); - uint8_t val0 = static_cast(v0 & 0xF); + uint8_t val0 = ToInt4Converter::Convert(v0); // Parse second value (or use 0 if odd number of elements) uint8_t val1 = 0; if (in_idx + 1 < shape_size) { int v1 = std::stoi(in_data[in_idx + 1]); - val1 = static_cast(v1 & 0xF); + val1 = ToInt4Converter::Convert(v1); } out_data[i] = UInt4x2(val0, val1); From 003adaebe1d0a26669bd4607c2058f85c7afabea Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Sat, 19 Jul 2025 22:25:50 -0700 Subject: [PATCH 85/88] Merge string -> int4 and string -> uint4 specializations --- .../core/providers/cpu/tensor/cast_op.cc | 57 
++++++-------------
 1 file changed, 16 insertions(+), 41 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
index d16db3142adb8..28223559a8c88 100644
--- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc
+++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
@@ -333,7 +333,7 @@ struct TensorCaster {
   }
 };
 
-// tensor X -> string
+// tensor X -> string, if X != (U)Int4x2
 template
 struct TensorCaster::value>> {
@@ -347,9 +347,10 @@ struct TensorCaster
   }
 };
 
-// tensor string -> X
+// tensor string -> X, if X != (U)Int4x2
 template
-struct TensorCaster {
+struct TensorCaster::value>> {
   void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
     const std::ptrdiff_t shape_size = narrow<std::ptrdiff_t>(shape.Size());
     const auto* in_data = in.Data<std::string>();
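The next hunk collapses the separate string -> Int4x2 and string -> UInt4x2 casters into a single trait-gated partial specialization. As a generic illustration of that enable_if pattern (toy types only, not ONNX Runtime's actual traits; not part of the patch):

#include <iostream>
#include <string>
#include <type_traits>

// One primary template for the common (numeric) case...
template <typename Dst, typename Enable = void>
struct Converter {
  static Dst Convert(int v) { return static_cast<Dst>(v); }
};

// ...and a single specialization selected by a trait, instead of one
// hand-written specialization per destination type.
template <typename Dst>
struct Converter<Dst, std::enable_if_t<std::is_same_v<Dst, std::string>>> {
  static std::string Convert(int v) { return std::to_string(v); }
};

int main() {
  std::cout << Converter<float>::Convert(7) << "\n";        // numeric path
  std::cout << Converter<std::string>::Convert(7) << "\n";  // string path
  return 0;
}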
@@ -360,59 +361,33 @@ struct TensorCaster {
   }
 };
 
-// tensor string -> Int4x2
-template <>
-struct TensorCaster<std::string, Int4x2> {
-  void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
-    const ptrdiff_t shape_size = narrow<ptrdiff_t>(shape.Size());
-    const auto* in_data = in.Data<std::string>();
-    auto* out_data = out.MutableData<Int4x2>();
-
-    // Every 2 strings combine into 1 Int4x2
-    const ptrdiff_t out_size = (shape_size + 1) >> 1;
-    for (ptrdiff_t i = 0; i < out_size; ++i) {
-      const ptrdiff_t in_idx = i << 1;
-
-      int v0 = std::stoi(in_data[in_idx]);
-      int8_t val0 = ToInt4Converter<int, Int4x2>::Convert(v0);
-
-      // Parse second value (or use 0 if odd number of elements)
-      int8_t val1 = 0;
-      if (in_idx + 1 < shape_size) {
-        int v1 = std::stoi(in_data[in_idx + 1]);
-        val1 = ToInt4Converter<int, Int4x2>::Convert(v1);
-      }
-
-      out_data[i] = Int4x2(val0, val1);
-    }
-  }
-};
-
-// tensor string -> UInt4x2
-template <>
-struct TensorCaster<std::string, UInt4x2> {
+// tensor string -> (U)Int4x2
+template <typename DstType>
+struct TensorCaster<std::string, DstType, std::enable_if_t<IsOrtInt4Type<DstType>::value>> {
   void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
     const ptrdiff_t shape_size = narrow<ptrdiff_t>(shape.Size());
     const auto* in_data = in.Data<std::string>();
-    auto* out_data = out.MutableData<UInt4x2>();
+    auto* out_data = out.MutableData<DstType>();
 
-    // Every 2 strings combine into 1 UInt4x2
+    // Every 2 strings combine into 1 (U)Int4x2
     const ptrdiff_t out_size = (shape_size + 1) >> 1;
     for (ptrdiff_t i = 0; i < out_size; ++i) {
       const ptrdiff_t in_idx = i << 1;
 
-      // Parse first value and truncate to lower 4 bits
+      // Parse first value and truncate to lower 4 bits.
+      // Sign extend if needed for Int4x2.
       int v0 = std::stoi(in_data[in_idx]);
-      uint8_t val0 = ToInt4Converter<int, UInt4x2>::Convert(v0);
+      typename DstType::UnpackedType val0 = ToInt4Converter<int, DstType>::Convert(v0);
 
       // Parse second value (or use 0 if odd number of elements)
-      uint8_t val1 = 0;
+      typename DstType::UnpackedType val1 = 0;
       if (in_idx + 1 < shape_size) {
         int v1 = std::stoi(in_data[in_idx + 1]);
-        val1 = ToInt4Converter<int, UInt4x2>::Convert(v1);
+        val1 = ToInt4Converter<int, DstType>::Convert(v1);
       }
 
-      out_data[i] = UInt4x2(val0, val1);
+      out_data[i] = DstType(val0, val1);
     }
   }
 };

From 27040c1fada9f4671947dd5d1d7e9be8af8fa6f6 Mon Sep 17 00:00:00 2001
From: Alex Marin
Date: Sat, 19 Jul 2025 22:52:59 -0700
Subject: [PATCH 86/88] parse string as double, add test

---
 .../core/providers/cpu/tensor/cast_op.cc      |  8 +++----
 .../test/providers/cpu/tensor/cast_op_test.cc | 21 +++++++++++++++++++
 2 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
index 28223559a8c88..121e4e60a3ce7 100644
--- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc
+++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
@@ -377,14 +377,14 @@ struct TensorCaster
       const ptrdiff_t in_idx = i << 1;
 
-      int v0 = std::stoi(in_data[in_idx]);
-      typename DstType::UnpackedType val0 = ToInt4Converter<int, DstType>::Convert(v0);
+      double v0 = std::stod(in_data[in_idx]);
+      typename DstType::UnpackedType val0 = ToInt4Converter<double, DstType>::Convert(v0);
 
       // Parse second value (or use 0 if odd number of elements)
       typename DstType::UnpackedType val1 = 0;
       if (in_idx + 1 < shape_size) {
-        int v1 = std::stoi(in_data[in_idx + 1]);
-        val1 = ToInt4Converter<int, DstType>::Convert(v1);
+        double v1 = std::stod(in_data[in_idx + 1]);
+        val1 = ToInt4Converter<double, DstType>::Convert(v1);
       }
 
       out_data[i] = DstType(val0, val1);
diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
index 7542ce4942719..c9dd7848e1db7 100644
--- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
@@ -1218,6 +1218,27 @@ TEST(CastOpTest, StringToUInt4x2BoundaryValues) {
   TestCastOp<std::string, UInt4x2>(gsl::span(string_input), gsl::span(expected_output), shape);
 }
 
+TEST(CastOpTest, FloatStringToInt4x2) {
+  // GIVEN
+  const std::vector<int64_t> shape{2, 2, 2};
+  const std::vector<std::string> string_input = {
+      "-10.7", "255.3",
+      "0.4", "2",
+      "6.8", "240.2",
+      "15.0", "-8"
+  };
+
+  const std::vector<Int4x2> expected_int4x2_output = {
+      Int4x2(5, -1),  // -11 -> 5, 255 -> -1
+      Int4x2(0, 2),
+      Int4x2(7, 0),
+      Int4x2(-1, -8)
+  };
+
+  // WHEN, THEN
+  TestCastOp<std::string, Int4x2>(gsl::span(string_input), gsl::span(expected_int4x2_output), shape);
+}
+
 #if !defined(DISABLE_FLOAT8_TYPES)
 
 template
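A standalone sketch of the string path this patch switches to: parse with std::stod, round, then reuse the same low-nibble truncation. The rounding call here is an assumption for illustration; the expected values in FloatStringToInt4x2 above are consistent with round-to-nearest (illustration only, not part of the patch):

#include <cassert>
#include <cmath>
#include <cstdint>
#include <string>

int8_t StringToSigned4Bits(const std::string& s) {
  int rounded = static_cast<int>(std::nearbyint(std::stod(s)));
  uint8_t truncated = static_cast<uint8_t>(rounded) & 0x0F;
  return static_cast<int8_t>((truncated & 0x8) ? (truncated | 0xF0) : truncated);
}

int main() {
  assert(StringToSigned4Bits("-10.7") == 5);   // -11 -> 0xF5 -> low nibble 5
  assert(StringToSigned4Bits("255.3") == -1);  // 255 -> 0xFF -> sign-extends to -1
  assert(StringToSigned4Bits("-8") == -8);
  return 0;
}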
From 9a64f1cc981f9df5809d6d120d0e6c983ebb0617 Mon Sep 17 00:00:00 2001
From: Alex Marin
Date: Sat, 19 Jul 2025 23:10:16 -0700
Subject: [PATCH 87/88] Merge string -> int4 with numeric -> int4 specializations

---
 .../core/providers/cpu/tensor/cast_op.cc      | 54 +++++++------------
 .../test/providers/cpu/tensor/cast_op_test.cc |  8 ++-
 2 files changed, 22 insertions(+), 40 deletions(-)

diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
index 121e4e60a3ce7..685937049e58f 100644
--- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc
+++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc
@@ -230,7 +230,7 @@ struct FromInt4Converter {
 
 // Helper for converting any source type to (U)Int4x2::UnpackedType values (int8_t and uint8_t).
 template <typename SrcType, typename DstType,
-          typename Enable = std::enable_if_t<IsOrtInt4NumericConversionType<SrcType>::value && IsOrtInt4Type<DstType>::value>>
+          typename Enable = std::enable_if_t<IsOrtInt4ConversionType<SrcType>::value && IsOrtInt4Type<DstType>::value>>
 struct ToInt4Converter {
   static typename DstType::UnpackedType Convert(const SrcType& val);
 };
@@ -271,14 +271,16 @@ struct ToInt4Converter
+// bool -> (U)Int4x2
 template
 struct ToInt4Converter::value>> {
-  static typename DstType::UnpackedType Convert(const bool& val) {
+  static typename DstType::UnpackedType Convert(bool val) {
     return static_cast(val ? 1 : 0);
   }
 };
 
+// float -> (U)Int4x2
 // Per https://onnx.ai/onnx/operators/onnx__Cast.html#summary, casting from
 // floating point to fixed point is undefined if OOR.
 template
@@ -290,6 +292,7 @@ struct ToInt4Converter (U)Int4x2
 template
 struct ToInt4Converter::value>> {
@@ -299,6 +302,7 @@ struct ToInt4Converter (U)Int4x2
 template
 struct ToInt4Converter::value && IsOrtInt4Type::value>> {
@@ -308,6 +312,7 @@ struct ToInt4Converter (U)Int4x2
 template
 struct ToInt4Converter::value && IsOrtInt4Type::value>> {
@@ -317,6 +322,16 @@ struct ToInt4Converter
+// string -> (U)Int4x2
+template <typename DstType>
+struct ToInt4Converter<std::string, DstType, std::enable_if_t<IsOrtInt4Type<DstType>::value>> {
+  static typename DstType::UnpackedType Convert(const std::string& val) {
+    double result = std::stod(val);
+    return ToInt4Converter<double, DstType>::Convert(result);
+  }
+};
+
 // generic tensor X -> Y
 template
 struct TensorCaster {
@@ -361,37 +376,6 @@ struct TensorCaster
   }
 };
 
-// tensor string -> (U)Int4x2
-template <typename DstType>
-struct TensorCaster<std::string, DstType, std::enable_if_t<IsOrtInt4Type<DstType>::value>> {
-  void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
-    const ptrdiff_t shape_size = narrow<ptrdiff_t>(shape.Size());
-    const auto* in_data = in.Data<std::string>();
-    auto* out_data = out.MutableData<DstType>();
-
-    // Every 2 strings combine into 1 (U)Int4x2
-    const ptrdiff_t out_size = (shape_size + 1) >> 1;
-    for (ptrdiff_t i = 0; i < out_size; ++i) {
-      const ptrdiff_t in_idx = i << 1;
-
-      // Parse first value and truncate to lower 4 bits.
-      // Sign extend if needed for Int4x2.
-      double v0 = std::stod(in_data[in_idx]);
-      typename DstType::UnpackedType val0 = ToInt4Converter<double, DstType>::Convert(v0);
-
-      // Parse second value (or use 0 if odd number of elements)
-      typename DstType::UnpackedType val1 = 0;
-      if (in_idx + 1 < shape_size) {
-        double v1 = std::stod(in_data[in_idx + 1]);
-        val1 = ToInt4Converter<double, DstType>::Convert(v1);
-      }
-
-      out_data[i] = DstType(val0, val1);
-    }
-  }
-};
-
 // tensor MLFloat16 -> float
 template <>
 struct TensorCaster<MLFloat16, float> {
@@ -431,10 +415,10 @@ struct TensorCaster
-// numeric types -> (U)Int4x2
+// string or numeric types -> (U)Int4x2
 template <typename SrcType, typename DstType>
 struct TensorCaster<SrcType, DstType,
-                    std::enable_if_t<IsOrtInt4NumericConversionType<SrcType>::value && IsOrtInt4Type<DstType>::value>> {
+                    std::enable_if_t<IsOrtInt4ConversionType<SrcType>::value && IsOrtInt4Type<DstType>::value>> {
   void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const {
     const ptrdiff_t shape_size = narrow<ptrdiff_t>(shape.Size());
     const auto* in_data = in.Data<SrcType>();
diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
index c9dd7848e1db7..fefa98536d319 100644
--- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
@@ -1225,15 +1225,13 @@ TEST(CastOpTest, FloatStringToInt4x2) {
       "-10.7", "255.3",
       "0.4", "2",
       "6.8", "240.2",
-      "15.0", "-8"
-  };
+      "15.0", "-8"};
 
   const std::vector<Int4x2> expected_int4x2_output = {
-      Int4x2(5, -1), // -11 -> 5, 255 -> -1
+      Int4x2(5, -1),  // -11 -> 5, 255 -> -1
       Int4x2(0, 2),
       Int4x2(7, 0),
-      Int4x2(-1, -8)
-  };
+      Int4x2(-1, -8)};
 
   // WHEN, THEN
   TestCastOp<std::string, Int4x2>(gsl::span(string_input), gsl::span(expected_int4x2_output), shape);

From 416635f84299bf16d4a049404f53589d979ede41 Mon Sep 17 00:00:00 2001
From: Alex Marin
Date: Tue, 22 Jul 2025 09:01:50 -0700
Subject: [PATCH 88/88] Add test

---
 .../test/providers/cpu/tensor/cast_op_test.cc | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
index fefa98536d319..68d4f3559504a 100644
--- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
@@ -1416,6 +1416,26 @@ TEST(CastOpTest, Float8E4M3FNToInt4x2) {
   TestCastOp<Float8E4M3FN, Int4x2>(gsl::make_span(float8_input), gsl::make_span(expected_int4x2_output), shape);
 }
 
+TEST(CastOpTest, Float8E4M3FNToInt4x2_OddShape) {
+  // GIVEN
+  const std::vector<int64_t> shape{1, 2, 3};
+  std::vector<Float8E4M3FN> float8_input;
+  const std::vector<float> input_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f};
+  for (float val : input_values) {
+    float8_input.emplace_back(Float8E4M3FN(val, true));
+  }
+
+  const std::vector<Int4x2> expected_int4x2_output = {
+      Int4x2(-8, 7),
+      Int4x2(0, -1),
+      Int4x2(3, -5)};
+
+  // WHEN, THEN
+  // The 'saturate_' bool inside the 'Cast' class can only be false if the conversion is to a float 8 type,
+  // so it's sufficient to test with the default saturate = 1 here, since we are not converting to float 8.
+  TestCastOp<Float8E4M3FN, Int4x2>(gsl::make_span(float8_input), gsl::make_span(expected_int4x2_output), shape);
+}
+
 TEST(CastOpTest, Float8E4M3FNToUInt4x2) {
   // GIVEN
   const std::vector<int64_t> shape{2, 2, 2};