diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index e4c7830ffbb55..6bcbdc401619a 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -207,15 +207,29 @@ Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto ORT_RETURN_IF_ERROR( GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_file_path, file_offset, tensor_byte_size)); - unpacked_tensor.resize(tensor_byte_size); - if (external_file_path == kTensorProtoMemoryAddressTag) { // The external data is in the same memory as the tensor proto. // The offset is the address of the data. + unpacked_tensor.resize(tensor_byte_size); std::memcpy(unpacked_tensor.data(), reinterpret_cast(file_offset), tensor_byte_size); return Status::OK(); } + // Validate that the external file is large enough before allocating. + // This protects against a model with a huge declared shape but a missing/short external file. + std::error_code fs_error_code{}; + std::uintmax_t file_length = std::filesystem::file_size(external_file_path, fs_error_code); + ORT_RETURN_IF(fs_error_code, "Failed to get file size for external initializer ", tensor_proto.name(), + ". std::filesystem error: ", fs_error_code.message(), " (value: ", fs_error_code.value(), ")"); + SafeInt end_of_read(file_offset); + end_of_read += tensor_byte_size; + ORT_RETURN_IF(file_offset < 0 || static_cast(end_of_read) > file_length, + "External initializer: ", tensor_proto.name(), " offset: ", file_offset, + " size to read: ", static_cast(tensor_byte_size), " given file_length: ", file_length, + " are out of bounds or cannot be read in full."); + + unpacked_tensor.resize(tensor_byte_size); + ORT_RETURN_IF_ERROR(onnxruntime::Env::Default().ReadFileIntoBuffer( external_file_path.c_str(), file_offset, @@ -239,6 +253,13 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "string tensor requires allocator to be provided."); } + // Validate data consistency and enforce size limits before allocating memory. + // This prevents a malicious model from triggering a massive allocation with a + // large declared shape that has no/insufficient actual data. + if (!utils::HasExternalData(tensor_proto)) { + ORT_RETURN_IF_ERROR(utils::ValidateEmbeddedTensorProtoDataSizeAndShape(tensor_proto)); + } + // Note: We permit an empty tensor_shape_vec, and treat it as a scalar (a tensor of size 1). TensorShape tensor_shape = GetTensorShapeFromTensorProto(tensor_proto); const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); @@ -1126,8 +1147,8 @@ INSTANTIATE_UNPACK_TENSOR(UInt2x4) break; template -common::Status GetSizeInBytesFromTensorShapeAndType(const TensorShape& shape, int32_t element_type, size_t* out) { - const auto size = narrow(shape.Size()); +static common::Status GetSizeInBytesFromTensorElemCountAndType(size_t elem_count, int32_t element_type, size_t* out) { + const size_t size = elem_count; // Used by CASE_PROTO_TRACE macros switch (element_type) { CASE_PROTO_TRACE(FLOAT, float); CASE_PROTO_TRACE(DOUBLE, double); @@ -1163,6 +1184,12 @@ common::Status GetSizeInBytesFromTensorShapeAndType(const TensorShape& shape, in return Status::OK(); } +template +static common::Status GetSizeInBytesFromTensorShapeAndType(const TensorShape& shape, int32_t element_type, size_t* out) { + const auto size = narrow(shape.Size()); + return GetSizeInBytesFromTensorElemCountAndType(size, element_type, out); +} + template common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto, size_t* out) { TensorShape tensor_shape = GetTensorShapeFromTensorProto(tensor_proto); @@ -1200,6 +1227,110 @@ common::Status GetSizeInBytesFromTensorTypeProto(const ONNX_NAMESPACE::TypeProto template Status GetSizeInBytesFromTensorTypeProto<0>(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_proto, size_t* out); +common::Status ValidateEmbeddedTensorProtoDataSizeAndShape(const ONNX_NAMESPACE::TensorProto& tensor_proto) { + ORT_RETURN_IF(HasExternalData(tensor_proto), "Expected to validate an embedded (non-external) TensorProto"); + + TensorShape tensor_shape = GetTensorShapeFromTensorProto(tensor_proto); + const int64_t num_elems_signed = tensor_shape.Size(); // returns -1 if any dim is negative. + + ORT_RETURN_IF(num_elems_signed < 0, "Initializer '", tensor_proto.name(), "' has negative dimensions"); + + // Need to ensure num_elements < SIZE_MAX. This would be an issue in 32-bit platforms. + ORT_RETURN_IF(static_cast(num_elems_signed) > std::numeric_limits::max(), + "Initializer '", tensor_proto.name(), "' has a number of elements (", num_elems_signed, + ") that exceeds SIZE_MAX (", std::numeric_limits::max(), ")"); + + const size_t num_elems_unsigned = gsl::narrow_cast(num_elems_signed); + size_t byte_size_from_shape = 0; + ORT_RETURN_IF_ERROR(GetSizeInBytesFromTensorElemCountAndType<0>(num_elems_unsigned, tensor_proto.data_type(), + &byte_size_from_shape)); + ORT_RETURN_IF_NOT(byte_size_from_shape <= kMaxEmbeddedInitializerSizeInBytes, + "Initializer '", tensor_proto.name(), "' declares a size of ", byte_size_from_shape, + " bytes which exceeds the ", kMaxEmbeddedInitializerSizeInBytes, + " byte limit for embedded initializer data. Use external data for large initializers."); + + if (HasRawData(tensor_proto)) { + ORT_RETURN_IF_NOT(tensor_proto.raw_data().size() == byte_size_from_shape, + "Initializer '", tensor_proto.name(), "': raw_data size (", tensor_proto.raw_data().size(), + " bytes) does not match expected size from shape and data type (", + byte_size_from_shape, " bytes)"); + } else if (HasString(tensor_proto)) { + ORT_RETURN_IF_NOT(tensor_proto.string_data_size() == num_elems_signed, + "Initializer '", tensor_proto.name(), "': string_data count (", tensor_proto.string_data_size(), + ") does not match expected count from shape (", + num_elems_signed, ")"); + } else { + // Typed data fields. Each data type maps to a specific repeated field in the proto. + int64_t expected_count = 0; + int64_t actual_count = 0; + + switch (tensor_proto.data_type()) { + case TensorProto_DataType_FLOAT: + expected_count = num_elems_signed; + actual_count = tensor_proto.float_data_size(); + break; + case TensorProto_DataType_DOUBLE: + expected_count = num_elems_signed; + actual_count = tensor_proto.double_data_size(); + break; + case TensorProto_DataType_INT64: + expected_count = num_elems_signed; + actual_count = tensor_proto.int64_data_size(); + break; + case TensorProto_DataType_UINT64: + case TensorProto_DataType_UINT32: + expected_count = num_elems_signed; + actual_count = tensor_proto.uint64_data_size(); + break; + case TensorProto_DataType_INT4: + case TensorProto_DataType_UINT4: + expected_count = static_cast(Int4x2::CalcNumInt4Pairs(num_elems_unsigned)); + actual_count = tensor_proto.int32_data_size(); + break; + case TensorProto_DataType_INT2: + case TensorProto_DataType_UINT2: + expected_count = static_cast(Int2x4::CalcNumInt2Quads(num_elems_unsigned)); + actual_count = tensor_proto.int32_data_size(); + break; +#if !defined(DISABLE_FLOAT4_TYPES) + case TensorProto_DataType_FLOAT4E2M1: + expected_count = static_cast(Float4E2M1x2::CalcNumFloat4Pairs(num_elems_unsigned)); + actual_count = tensor_proto.int32_data_size(); + break; +#endif + case TensorProto_DataType_BOOL: + case TensorProto_DataType_UINT8: + case TensorProto_DataType_INT8: + case TensorProto_DataType_UINT16: + case TensorProto_DataType_INT16: + case TensorProto_DataType_FLOAT16: + case TensorProto_DataType_BFLOAT16: +#if !defined(DISABLE_FLOAT8_TYPES) + case TensorProto_DataType_FLOAT8E4M3FN: + case TensorProto_DataType_FLOAT8E4M3FNUZ: + case TensorProto_DataType_FLOAT8E5M2: + case TensorProto_DataType_FLOAT8E5M2FNUZ: +#endif + case TensorProto_DataType_INT32: + // BOOL, INT8, UINT8, INT16, UINT16, FLOAT16, BFLOAT16, INT32, FLOAT8* all use int32_data + expected_count = num_elems_signed; + actual_count = tensor_proto.int32_data_size(); + break; + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unhandled TensorProto_DataType (", tensor_proto.data_type(), + ") in ValidateEmbeddedTensorProtoDataSizeAndShape()"); + } + + ORT_RETURN_IF_NOT(actual_count == expected_count, + "Initializer '", tensor_proto.name(), + "': data field count (", actual_count, + ") does not match expected count from shape (", + expected_count, ")"); + } + + return Status::OK(); +} + TensorShape GetTensorShapeFromTensorShapeProto(const ONNX_NAMESPACE::TensorShapeProto& tensor_shape_proto) { const auto& dims = tensor_shape_proto.dim(); TensorShapeVector tensor_shape_vec(static_cast(dims.size())); @@ -1572,6 +1703,12 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa common::Status CreateTensorFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor) { ORT_RETURN_IF_NOT(utils::HasDataType(tensor_proto), "Initializer must have a datatype"); + + // Validate data consistency and enforce size limits before allocating memory. + if (!utils::HasExternalData(tensor_proto)) { + ORT_RETURN_IF_ERROR(ValidateEmbeddedTensorProtoDataSizeAndShape(tensor_proto)); + } + auto proto_data_type = tensor_proto.data_type(); auto proto_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 941cd9af34b61..f3f33a32b8076 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -166,6 +166,12 @@ common::Status CreateTensorFromTensorProto(const Env& env, const std::filesystem /// in shape inferencing, it is cheaper to inline them. constexpr const size_t kSmallTensorExternalDataThreshold = 127; // 127 bytes +/// Max in-memory tensor size (from shape × dtype) allowed for embedded (non-external) initializers. +/// This is an allocation guard to prevent a malicious model from triggering excessive memory allocation. +/// 2 GiB is chosen as a practical upper bound: valid ONNX protobuf messages cannot exceed ~2 GiB of serialized data, +/// so any embedded initializer whose in-memory representation exceeds this is highly suspect. +constexpr const size_t kMaxEmbeddedInitializerSizeInBytes = size_t{2} * 1024 * 1024 * 1024; // 2 GiB + /** * @brief Creates a TensorProto from a Tensor. * @param[in] tensor the Tensor whose data and shape will be used to create the TensorProto. @@ -198,6 +204,13 @@ common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& template Status GetSizeInBytesFromTensorTypeProto(const ONNX_NAMESPACE::TypeProto_Tensor& tensor_proto, size_t* out); +/// Validates that the size of the actual data content in a non-external TensorProto is consistent with its +/// declared shape and data type. This prevents allocating memory based on a maliciously large +/// declared shape when the actual data is absent or much smaller. +/// The caller must ensure that the TensorProto does not use external data; if it does, this function will +/// return an error status. +common::Status ValidateEmbeddedTensorProtoDataSizeAndShape(const ONNX_NAMESPACE::TensorProto& tensor_proto); + /** Special marker used to indicate an existing memory buffer contains the TensorProto external data. If the 'location' field of the external data info is set to this marker, the 'offset' field should contain the diff --git a/onnxruntime/test/framework/tensorutils_test.cc b/onnxruntime/test/framework/tensorutils_test.cc index c9b61a7a39632..424d6cbac743c 100644 --- a/onnxruntime/test/framework/tensorutils_test.cc +++ b/onnxruntime/test/framework/tensorutils_test.cc @@ -606,5 +606,242 @@ TEST_F(PathValidationTest, ValidateExternalDataPathWithSymlinkOutside) { ASSERT_FALSE(utils::ValidateExternalDataPath(base_dir_, "outside_link.bin").IsOK()); } +// Tests for ValidateEmbeddedTensorProtoDataSizeAndShape and embedded initializer size limits + +TEST(TensorProtoDataSizeShapeValidationTest, ValidTensorProtoWithRawData) { + // A valid float tensor with 4 elements and matching raw_data + TensorProto tensor_proto; + tensor_proto.set_name("valid_raw"); + tensor_proto.set_data_type(TensorProto_DataType_FLOAT); + tensor_proto.add_dims(2); + tensor_proto.add_dims(2); + // 4 floats = 16 bytes + std::string raw(16, '\0'); + tensor_proto.set_raw_data(raw); + + ASSERT_STATUS_OK(utils::ValidateEmbeddedTensorProtoDataSizeAndShape(tensor_proto)); +} + +TEST(TensorProtoDataSizeShapeValidationTest, ValidTensorProtoWithTypedData) { + // A valid float tensor with typed float_data + TensorProto tensor_proto; + tensor_proto.set_name("valid_typed"); + tensor_proto.set_data_type(TensorProto_DataType_FLOAT); + tensor_proto.add_dims(3); + tensor_proto.add_float_data(1.0f); + tensor_proto.add_float_data(2.0f); + tensor_proto.add_float_data(3.0f); + + ASSERT_STATUS_OK(utils::ValidateEmbeddedTensorProtoDataSizeAndShape(tensor_proto)); +} + +TEST(TensorProtoDataSizeShapeValidationTest, ValidZeroElementTensor) { + // A valid zero-element tensor (one dim is 0) + TensorProto tensor_proto; + tensor_proto.set_name("zero_elem"); + tensor_proto.set_data_type(TensorProto_DataType_FLOAT); + tensor_proto.add_dims(0); + tensor_proto.add_dims(5); + + ASSERT_STATUS_OK(utils::ValidateEmbeddedTensorProtoDataSizeAndShape(tensor_proto)); +} + +TEST(TensorProtoDataSizeShapeValidationTest, LargeDimsNoDataRejected) { + // Malicious: large dims but no data at all + TensorProto tensor_proto; + tensor_proto.set_name("malicious_no_data"); + tensor_proto.set_data_type(TensorProto_DataType_FLOAT); + tensor_proto.add_dims(10000); + tensor_proto.add_dims(10000); + // No raw_data or float_data set + + auto status = utils::ValidateEmbeddedTensorProtoDataSizeAndShape(tensor_proto); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("does not match expected count from shape")); +} + +TEST(TensorProtoDataSizeShapeValidationTest, LargeDimsSmallRawDataRejected) { + // Malicious: large dims with tiny raw_data + TensorProto tensor_proto; + tensor_proto.set_name("malicious_small_raw"); + tensor_proto.set_data_type(TensorProto_DataType_FLOAT); + tensor_proto.add_dims(10000); + tensor_proto.add_dims(10000); + // Only 4 bytes of raw data (1 float), but shape says 100M elements + std::string raw(4, '\0'); + tensor_proto.set_raw_data(raw); + + auto status = utils::ValidateEmbeddedTensorProtoDataSizeAndShape(tensor_proto); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("does not match expected size from shape and data type")); +} + +TEST(TensorProtoDataSizeShapeValidationTest, LargeDimsSmallTypedDataRejected) { + // Malicious: large dims with just a few typed data elements + TensorProto tensor_proto; + tensor_proto.set_name("malicious_small_typed"); + tensor_proto.set_data_type(TensorProto_DataType_FLOAT); + tensor_proto.add_dims(10000); + tensor_proto.add_dims(10000); + tensor_proto.add_float_data(1.0f); + + auto status = utils::ValidateEmbeddedTensorProtoDataSizeAndShape(tensor_proto); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("does not match expected count from shape")); +} + +TEST(TensorProtoDataSizeShapeValidationTest, EmbeddedInitializerExceeding2GiBRejected) { + // A tensor whose declared shape exceeds 2 GiB should be rejected by TensorProtoToOrtValue and + // CreateTensorFromTensorProto. + TensorProto tensor_proto; + tensor_proto.set_name("too_large"); + tensor_proto.set_data_type(TensorProto_DataType_FLOAT); + // 536870913 floats * 4 bytes = 2147483652 bytes > 2 GiB + tensor_proto.add_dims(536870913); + // No data — the 2 GiB check should trigger before the consistency check + + // Test call to TensorProtoToOrtValue + { + OrtValue ort_value; + auto status = utils::TensorProtoToOrtValue(Env::Default(), std::filesystem::path{}, + tensor_proto, CPUAllocator::DefaultInstance(), ort_value); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("exceeds the 2147483648 byte limit")); + } + + // Test call to CreateTensorFromTensorProto + { + Tensor tensor; + auto status = utils::CreateTensorFromTensorProto(Env::Default(), std::filesystem::path{}, + tensor_proto, tensor); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("exceeds the 2147483648 byte limit")); + } +} + +TEST(TensorProtoDataSizeShapeValidationTest, ValidStringTensorProto) { + // A valid string tensor with matching string_data + TensorProto tensor_proto; + tensor_proto.set_name("valid_string"); + tensor_proto.set_data_type(TensorProto_DataType_STRING); + tensor_proto.add_dims(2); + tensor_proto.add_string_data("hello"); + tensor_proto.add_string_data("world"); + + ASSERT_STATUS_OK(utils::ValidateEmbeddedTensorProtoDataSizeAndShape(tensor_proto)); +} + +TEST(TensorProtoDataSizeShapeValidationTest, StringTensorWithMismatchedCountRejected) { + TensorProto tensor_proto; + tensor_proto.set_name("bad_string"); + tensor_proto.set_data_type(TensorProto_DataType_STRING); + tensor_proto.add_dims(100); + tensor_proto.add_string_data("only_one"); + + auto status = utils::ValidateEmbeddedTensorProtoDataSizeAndShape(tensor_proto); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("does not match expected count from shape")); +} + +TEST(TensorProtoDataSizeShapeValidationTest, NegativeDimsRejected) { + TensorProto tensor_proto; + tensor_proto.set_name("negative_dims"); + tensor_proto.set_data_type(TensorProto_DataType_FLOAT); + tensor_proto.add_dims(-1); + tensor_proto.add_dims(10); + + auto status = utils::ValidateEmbeddedTensorProtoDataSizeAndShape(tensor_proto); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("negative dimensions")); +} + +#if !defined(__wasm__) +// Tests for external data file size validation in ReadExternalDataForTensor. +// These verify that the file size is checked before allocating memory for the tensor. + +TEST(TensorProtoDataSizeShapeValidationTest, ExternalDataFileTooSmallForDeclaredShape) { + // Create a small external data file with 4 floats (16 bytes) + std::basic_string filename(ORT_TSTR("ext_small_XXXXXX")); + FILE* fp; + CreateTestFile(fp, filename); + const float small_data[] = {1.0f, 2.0f, 3.0f, 4.0f}; + ASSERT_EQ(sizeof(small_data), fwrite(small_data, 1, sizeof(small_data), fp)); + ASSERT_EQ(0, fclose(fp)); + std::unique_ptr file_deleter( + const_cast(filename.c_str()), DeleteFileFromDisk); + + // Declare a tensor with 1000 floats (4000 bytes) but the file only has 16 bytes + TensorProto tensor_proto; + tensor_proto.set_name("malicious_external"); + tensor_proto.set_data_type(TensorProto_DataType_FLOAT); + tensor_proto.add_dims(1000); + tensor_proto.set_data_location(TensorProto_DataLocation_EXTERNAL); + auto* location = tensor_proto.add_external_data(); + location->set_key("location"); + location->set_value(ToUTF8String(filename)); + + std::vector unpacked_tensor; + auto status = utils::UnpackInitializerData(tensor_proto, std::filesystem::path{}, unpacked_tensor); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("out of bounds")); +} + +TEST(TensorProtoDataSizeShapeValidationTest, ExternalDataOffsetPushesReadPastEndOfFile) { + // Create an external data file with 4 floats (16 bytes) + std::basic_string filename(ORT_TSTR("ext_offset_XXXXXX")); + FILE* fp; + CreateTestFile(fp, filename); + const float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; + ASSERT_EQ(sizeof(data), fwrite(data, 1, sizeof(data), fp)); + ASSERT_EQ(0, fclose(fp)); + std::unique_ptr file_deleter( + const_cast(filename.c_str()), DeleteFileFromDisk); + + // Declare a tensor with 4 floats (16 bytes) but at offset 8, so read needs bytes [8..24) but file is only 16 bytes + TensorProto tensor_proto; + tensor_proto.set_name("offset_external"); + tensor_proto.set_data_type(TensorProto_DataType_FLOAT); + tensor_proto.add_dims(4); + tensor_proto.set_data_location(TensorProto_DataLocation_EXTERNAL); + auto* location = tensor_proto.add_external_data(); + location->set_key("location"); + location->set_value(ToUTF8String(filename)); + auto* offset = tensor_proto.add_external_data(); + offset->set_key("offset"); + offset->set_value("8"); + + std::vector unpacked_tensor; + auto status = utils::UnpackInitializerData(tensor_proto, std::filesystem::path{}, unpacked_tensor); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("out of bounds")); +} + +TEST(TensorProtoDataSizeShapeValidationTest, ExternalDataValidFileSizeSucceeds) { + // Create an external data file with exactly 4 floats (16 bytes) + std::basic_string filename(ORT_TSTR("ext_valid_XXXXXX")); + FILE* fp; + CreateTestFile(fp, filename); + const float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; + ASSERT_EQ(sizeof(data), fwrite(data, 1, sizeof(data), fp)); + ASSERT_EQ(0, fclose(fp)); + std::unique_ptr file_deleter( + const_cast(filename.c_str()), DeleteFileFromDisk); + + // Declare a tensor with matching shape (4 floats = 16 bytes) + TensorProto tensor_proto; + tensor_proto.set_name("valid_external"); + tensor_proto.set_data_type(TensorProto_DataType_FLOAT); + tensor_proto.add_dims(4); + tensor_proto.set_data_location(TensorProto_DataLocation_EXTERNAL); + auto* location = tensor_proto.add_external_data(); + location->set_key("location"); + location->set_value(ToUTF8String(filename)); + + std::vector unpacked_tensor; + ASSERT_STATUS_OK(utils::UnpackInitializerData(tensor_proto, std::filesystem::path{}, unpacked_tensor)); + ASSERT_EQ(unpacked_tensor.size(), sizeof(data)); +} +#endif // !defined(__wasm__) + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index bf632d0b3bc40..d1d1dd1c321af 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -60,6 +60,20 @@ TEST(DequantizeLinearOpTest, Int8_Large) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider}); } +TEST(DequantizeLinearOpTest, Int4_LargeInitializerInput) { + OpTester test("DequantizeLinear", 21); + std::vector dims{1024}; + + std::vector x_vals(Int4x2::CalcNumInt4Pairs(static_cast(dims[0])), Int4x2{}); + std::vector expected_y_vals(static_cast(dims[0]), 0.f); + + test.AddInput("x", dims, x_vals, true); + test.AddInput("x_scale", {}, {1.0f}); + test.AddInput("x_zero_point", {}, {Int4x2(0, 0)}); + test.AddOutput("y", dims, expected_y_vals); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + // scalar zero & scale with int4 TEST(DequantizeLinearOpTest, Int4) { OpTester test("DequantizeLinear", 21); @@ -131,6 +145,18 @@ TEST(DequantizeLinearOpTest, Int2) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +TEST(DequantizeLinearOpTest, Int2_LargeInitializerInput) { + OpTester test("DequantizeLinear", 25); + std::vector dims{4096}; + std::vector x_vals(Int2x4::CalcNumInt2Quads(static_cast(dims[0])), Int2x4()); + + test.AddInput("x", dims, x_vals, true); + test.AddInput("x_scale", {}, {1.0f}); + test.AddInput("x_zero_point", {}, {Int2x4()}); + test.AddOutput("y", dims, std::vector(static_cast(dims[0]), 0.0f)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + // scalar scale with int2 (no zero point) TEST(DequantizeLinearOpTest, Int2NoZeroPoint) { OpTester test("DequantizeLinear", 25); diff --git a/onnxruntime/test/unittest_util/base_tester.cc b/onnxruntime/test/unittest_util/base_tester.cc index 2e0459103a7c9..4d1b2a210599a 100644 --- a/onnxruntime/test/unittest_util/base_tester.cc +++ b/onnxruntime/test/unittest_util/base_tester.cc @@ -74,7 +74,9 @@ void BaseTester::AddInitializers(onnxruntime::Graph& graph) { tensor_proto.add_string_data(string_data[i]); } } else { - auto buffer_size = tensor.DataType()->Size() * shape.Size(); + // Note: need to use Tensor::CalculateTensorStorageSize (instead of shape.Size() * elem_size) to properly + // calculate the storage size for sub-byte types (e.g., Int4 or Int2) + auto buffer_size = Tensor::CalculateTensorStorageSize(tensor.DataType(), shape); utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), buffer_size); }