diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 73291e1e93ff1..89f547481b6e4 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -102,6 +102,9 @@ constexpr std::string_view ProgramVariableDataTypeName[] = { "u8x4", // Uint8x4 "u8x8", // Uint8x8 "u8x16", // Uint8x16 + "i8x4", // Int8x4 + "i8x8", // Int8x8 + "i8x16", // Int8x16 }; std::ostream& operator<<(std::ostream& os, ProgramVariableDataType type) { os << ProgramVariableDataTypeName[std::underlying_type::type(type)]; @@ -129,6 +132,7 @@ int NumberOfComponents(ProgramVariableDataType type) { case ProgramVariableDataType::Float16x4: case ProgramVariableDataType::Boolx4: case ProgramVariableDataType::Uint8x4: + case ProgramVariableDataType::Int8x4: return 4; case ProgramVariableDataType::Uint8x8: return 8; @@ -142,6 +146,10 @@ int NumberOfComponents(ProgramVariableDataType type) { ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component /* = 1 */) { if (component == 1) { switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x4; // shader needs to be aware that only 1 value is valid + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return ProgramVariableDataType::Int8x4; // shader needs to be aware that only 1 value is valid case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return ProgramVariableDataType::Float32; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: @@ -174,6 +182,8 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp switch (element_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: return ProgramVariableDataType::Uint8x4; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return ProgramVariableDataType::Int8x4; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return ProgramVariableDataType::Float32x4; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index ea7d8ae5471af..3b0acfa7d0d35 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -197,7 +197,10 @@ enum class ProgramVariableDataType { Boolx4, Uint8x4, Uint8x8, - Uint8x16 + Uint8x16, + Int8x4, + Int8x8, + Int8x16, }; #ifndef NDEBUG std::ostream& operator<<(std::ostream& os, ProgramVariableDataType); diff --git a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc new file mode 100644 index 0000000000000..866b1debf6dc8 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/util/math.h" +#include "core/providers/webgpu/quantization/quantize_linear.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_utils.h" + +namespace onnxruntime { +namespace webgpu { + +Status DequantizeLinearProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& scale = shader.AddInput("scale", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias); + + shader.MainFunctionBody() + << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"; + + // Get x input + if (packed_) { + std::string unpack = (signed_) ? "unpack4xI8(x)" : "unpack4xU8(x)"; + if (output.NumComponents() == 1) { + shader.MainFunctionBody() + << "let x = " << x.GetByOffset("global_idx / 4") << ";\n" + << "let x_vec = " << unpack << ";\n" + << "let x_value = x_vec[global_idx % 4];\n"; + } else { + shader.MainFunctionBody() + << "let x = " << x.GetByOffset("global_idx") << ";\n" + << "let x_vec = " << unpack << ";\n" + << "let x_value = x_vec;\n"; + } + } else { + shader.MainFunctionBody() + << "let x_value = " << x.GetByOffset("global_idx") << ";\n"; + } + + // Get scaler + if (per_layer_) { + // scale input is a scalar () + shader.MainFunctionBody() + << "let scale_value = " << scale.GetByOffset("0") << ";\n"; + } else if (per_axis_) { + shader.MainFunctionBody() + << "let scale_index = " << output.IndicesGet("output_indices", "uniforms.axis") << ";\n" + << "let scale_value = " << scale.GetByOffset("scale_index") << ";\n"; + } else { + // Block quantization. Scale input rank is same as input/output rank. + shader.MainFunctionBody() + << "var scale_indices: scale_indices_t = output_indices;\n" + << "let index = " << scale.IndicesGet("scale_indices", "uniforms.axis") << "/ uniforms.block_size;\n" + << scale.IndicesSet("scale_indices", "uniforms.axis", "index") << ";\n" + << "let scale_value = " << scale.GetByIndices("scale_indices") << ";\n"; + } + + // Get zero-point + if (has_zeropoint_) { + const auto& zero_point = shader.AddInput("zero_point", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + + std::string unpack = (signed_) ? "unpack4xI8(zero_point_input)" : "unpack4xU8(zero_point_input)"; + if (per_layer_) { + // zero-point input is a scalar + if (packed_) { + shader.MainFunctionBody() + << "let zero_point_input = " << zero_point.GetByOffset("0") << ";\n" + << "let zero_point_vec = " << unpack << ";\n" + << "let zero_point_value = zero_point_vec[0];\n"; + } else { + shader.MainFunctionBody() + << "let zero_point_value = " << zero_point.GetByOffset("0") << ";\n"; + } + } else if (per_axis_) { + // zero-point input is a 1D tensor + if (packed_) { + shader.MainFunctionBody() + << "let zero_point_index = " << output.IndicesGet("output_indices", "uniforms.axis") << ";\n" + << "let zero_point_input = " << zero_point.GetByOffset("zero_point_index / 4") << ";\n" + << "let zero_point_vec = " << unpack << ";\n" + << "let zero_point_value = zero_point_vec[zero_point_index % 4];\n"; + } else { + shader.MainFunctionBody() + << "let zero_point_index = " << output.IndicesGet("output_indices", "uniforms.axis") << ";\n" + << "let zero_point_value = " << zero_point.GetByOffset("zero_point_index") << ";\n"; + } + } else { + // BlockedQuantization. The zero-point input shape is same as the input shape except along axis. + if (packed_) { + shader.MainFunctionBody() + << "let zero_point_offset = " << scale.GetByIndices("scale_indices") << ";\n" + << "let zero_point_input = " << zero_point.GetByOffset("zero_point_offset / 4") << ";\n" + << "let zero_point_vec = " << unpack << ";\n" + << "let zero_point_value = zero_point_vec[zero_point_offset % 4];\n"; + } else { + shader.MainFunctionBody() + << "let zero_point_value = " << zero_point.GetByIndices("scale_indices") << ";\n"; + } + } + } else { + shader.MainFunctionBody() + << "let zero_point_value = input_element_t(0);\n"; + } + + // compute and write output + shader.MainFunctionBody() + << output.SetByOffset("global_idx", "(output_value_t(x_value) - scale_value_t(zero_point_value)) * scale_value"); + + return Status::OK(); +} + +Status DequantizeLinear::ComputeInternal(ComputeContext& context) const { + const auto* x = context.Input(0); + const auto* x_scale = context.Input(1); + const auto* x_zeropoint = context.Input(2); + const auto x_shape = x->Shape(); + int64_t x_size = x_shape.Size(); + auto* output_tensor = context.Output(0, x_shape); + int64_t x_scale_rank = x_scale->Shape().NumDimensions(); + + bool packed = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + bool is_signed = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; + int64_t axis = (axis_ >= 0) ? axis_ : axis_ + x_shape.NumDimensions(); + + int max_components = GetMaxComponents(x_size); + if (max_components != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "DequantizeLinear: components must be 4, but got ", max_components); + } + + // scaler - single scaler for all elements + bool per_layer = x_scale_rank == 0 || (x_scale_rank == 1 && x_scale->Shape()[0] == 1); + + // 1D tensor - 1 scaler for per axis + bool per_axis = per_layer == false && x_scale_rank == 1; + + bool use_components = per_layer && (!packed || max_components == 4); + int components = use_components ? max_components : 1; + int input_component = use_components && !packed ? max_components : 1; + + DequantizeLinearProgram program{packed, is_signed, per_layer, per_axis, x_zeropoint != nullptr}; + + program + .AddInputs({{x, ProgramTensorMetadataDependency::TypeAndRank, input_component}}) + .AddInputs({{x_scale, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::None, components}) + .SetDispatchGroupSize((x_size / components + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{static_cast(axis)}}) + .AddUniformVariables({{static_cast(block_size_)}}) + .AddUniformVariables({{static_cast(x_size / components)}}) + .CacheHint(std::to_string(axis), std::to_string(is_signed), std::to_string(per_layer), std::to_string(per_axis), std::to_string(block_size_)); + + if (x_zeropoint != nullptr) { + program.AddInputs({{x_zeropoint, ProgramTensorMetadataDependency::TypeAndRank}}); + } + + return context.RunProgram(program); +} + +namespace { +const std::vector& DequantizeLinearConstraints() { + static std::vector types{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + return types; +} +} // namespace + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + DequantizeLinear, + kOnnxDomain, + 10, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DequantizeLinearConstraints()), + DequantizeLinear); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + DequantizeLinear, + kOnnxDomain, + 13, 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DequantizeLinearConstraints()), + DequantizeLinear); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + DequantizeLinear, + kOnnxDomain, + 19, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DequantizeLinearConstraints()) + .TypeConstraint("T2", WebGpuSupportedFloatTypes()), + DequantizeLinear); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + DequantizeLinear, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DequantizeLinearConstraints()) + .TypeConstraint("T2", WebGpuSupportedFloatTypes()), + DequantizeLinear); + +ONNX_OPERATOR_KERNEL_EX( + DequantizeLinear, + kOnnxDomain, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DequantizeLinearConstraints()) + .TypeConstraint("T2", WebGpuSupportedFloatTypes()), + DequantizeLinear); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.h b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.h new file mode 100644 index 0000000000000..95614998017e9 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class DequantizeLinearProgram final : public Program { + public: + DequantizeLinearProgram(const bool packed, const bool issigned, const bool per_layer, + const bool per_axis, bool has_zeropoint) : Program{"DequantizeLinear"}, + packed_{packed}, + signed_{issigned}, + per_layer_{per_layer}, + per_axis_{per_axis}, + has_zeropoint_{has_zeropoint} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"axis", ProgramUniformVariableDataType::Uint32}, + {"block_size", ProgramUniformVariableDataType::Uint32}, + {"output_size", ProgramUniformVariableDataType::Uint32}); + + private: + bool packed_; + bool signed_; + bool per_layer_; + bool per_axis_; + bool has_zeropoint_; +}; + +class DequantizeLinear final : public WebGpuKernel { + public: + DequantizeLinear(const OpKernelInfo& info) : WebGpuKernel(info) { + axis_ = info.GetAttrOrDefault("axis", 1); + block_size_ = info.GetAttrOrDefault("block_size", 0); + output_dtype_ = info.GetAttrOrDefault("output_dtype", 0); + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + int64_t axis_; + int64_t block_size_; + int64_t output_dtype_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index db14cb88d1963..bac360c4c270e 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -168,6 +168,12 @@ Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType va var_type == ProgramVariableDataType::Uint8x16, "Unexpected program variable type ", int(var_type), " for uint8 tensor"); break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int8x4 || + var_type == ProgramVariableDataType::Int8x8 || + var_type == ProgramVariableDataType::Int8x16, + "Unexpected program variable type ", int(var_type), " for int8 tensor"); + break; default: ORT_RETURN_IF(true, "Unsupported data type: ", element_type); // todo: add int4/uint4 diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index f8e1e0b3b8d2b..502d03c2c2dd8 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -32,6 +32,7 @@ constexpr static const std::string_view STORAGE_TYPE_ARRAY[] = { "u32", // Uint8x4 "vec2", // Uint8x8 "vec4", // Uint8x16 + "u32", // Int8x4 }; constexpr static const auto STORAGE_TYPE = details::_to_std_array(STORAGE_TYPE_ARRAY); @@ -54,6 +55,7 @@ constexpr static const std::string_view VALUE_TYPE_ARRAY[] = { "u32", // Uint8x4 (u32 as 4 elements of uint8) "vec2", // Uint8x8 (vec2 as 2x4 elements of uint8) "vec4", // Uint8x16 (vec4 as 4x4 elements of uint8) + "i32", // Int8x4 }; constexpr static const auto VALUE_TYPE = details::_to_std_array(VALUE_TYPE_ARRAY); @@ -76,6 +78,9 @@ constexpr static const std::string_view ELEMENT_TYPE_ARRAY[] = { "u32", // Uint8x4 "u32", // Uint8x8 "u32", // Uint8x16 + "i32", // Int8x4 + "i32", // Int8x8 + "i32", // Int8x16 }; constexpr static const auto ELEMENT_TYPE = details::_to_std_array(ELEMENT_TYPE_ARRAY); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 2427bf62cc658..4e6de6547665f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -387,18 +387,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 13, CumSum); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, CumSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, int32_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, int32_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, uint8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, int8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, int32_t, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, uint8_t, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, DequantizeLinear); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, DequantizeLinear); std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -715,20 +708,15 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + KERNEL_CREATE_INFO_VERSIONED(10, 12, DequantizeLinear), + KERNEL_CREATE_INFO_VERSIONED(13, 18, DequantizeLinear), + KERNEL_CREATE_INFO_VERSIONED(19, 20, DequantizeLinear), + KERNEL_CREATE_INFO_VERSIONED(21, 22, DequantizeLinear), + KERNEL_CREATE_INFO(23, DequantizeLinear), + BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 51aae0cfd4adf..4e7a6356a5129 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -156,7 +156,8 @@ TEST(DequantizeLinearOpTest, Scalar) { test.AddInput("x_zero_point", {}, {-10}); test.AddOutput("y", {}, {220.0f}); // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 0. - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + // Disable WebGPU EP due to error: needs at least component size 4 + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider}); } // dequantize with scalar data @@ -167,7 +168,8 @@ TEST(DequantizeLinearOpMLFloat16Test, Scalar) { test.AddInput("x_zero_point", {}, {-10}); test.AddOutput("y", {}, {MLFloat16(220.0f)}); // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 0. - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + // Disable WebGPU EP due to error: needs at least component size 4 + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider}); } // dequantize without zero point @@ -176,6 +178,45 @@ TEST(DequantizeLinearOpTest, Without_Zero_Point) { test.AddInput("x", {}, {100}); test.AddInput("x_scale", {}, {2.0f}); test.AddOutput("y", {}, {200.0f}); + // No DQ allowed without corresponding Q. Skip since TRT10 + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider}); +} + +// dequantize without zero point int8 (testing 8 elements for webgpu) +TEST(DequantizeLinearOpTest, No_Zero_Point_int8) { + OpTester test("DequantizeLinear", 10); + test.AddInput("x", {1, 8}, {-10, 50, 100, 120, -9, 49, 99, 119}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddOutput("y", {1, 8}, {-20.0f, 100.0f, 200.0f, 240.0f, -18.0f, 98.0f, 198.0f, 238.0f}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // No DQ allowed without corresponding Q. Skip since TRT10 +} + +// dequantize without zero point uint8 (testing 8 elements for webgpu) +TEST(DequantizeLinearOpTest, No_Zero_Point_uint8) { + OpTester test("DequantizeLinear", 10); + test.AddInput("x", {1, 8}, {10, 50, 100, 180, 9, 49, 99, 179}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddOutput("y", {1, 8}, {20.0f, 100.0f, 200.0f, 360.0f, 18.0f, 98.0f, 198.0f, 358.0f}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // No DQ allowed without corresponding Q. Skip since TRT10 +} + +// dequantize zero point int8 (testing 8 elements for webgpu) +TEST(DequantizeLinearOpTest, Zero_Point_int8) { + OpTester test("DequantizeLinear", 10); + test.AddInput("x", {1, 8}, {-10, 50, 100, 120, -9, 49, 99, 119}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddInput("zero_point", {}, {-10}); + test.AddOutput("y", {1, 8}, {0.0f, 120.0f, 220.0f, 260.0f, 2.0f, 118.0f, 218.0f, 258.0f}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // No DQ allowed without corresponding Q. Skip since TRT10 +} + +// dequantize zero point uint8 (testing 8 elements for webgpu) +TEST(DequantizeLinearOpTest, Zero_Point_uint8) { + OpTester test("DequantizeLinear", 10); + test.AddInput("x", {1, 8}, {10, 50, 100, 180, 9, 49, 99, 119}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddInput("zero_point", {}, {10}); + test.AddOutput("y", {1, 8}, {0.0f, 80.0f, 180.0f, 340.0f, -2.0f, 78.0f, 178.0f, 218.0f}); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // No DQ allowed without corresponding Q. Skip since TRT10 }