diff --git a/js/web/lib/wasm/jsep/webgpu/ops/quantize-linear.ts b/js/web/lib/wasm/jsep/webgpu/ops/quantize-linear.ts index 52ecd07cb7f92..83c96cb8223ed 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/quantize-linear.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/quantize-linear.ts @@ -31,9 +31,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: DequantizeLin if (inputs.length === 3 && inputs[0].dataType !== inputs[2].dataType) { throw new Error('x and x-zero-point must have the same data type.'); } - if (inputs[0].dataType === DataType.int32 && inputs.length > 2) { - throw new Error('In the case of dequantizing int32 there is no zero point.'); - } if (inputs[1].dims.length !== 0 && inputs[1].dims.length !== 1 && inputs[1].dims.length !== inputs[0].dims.length) { throw new Error('scale input must be a scalar, a 1D tensor, or have the same rank as the input tensor.'); } diff --git a/js/web/test/data/ops/dequantizelinear.jsonc b/js/web/test/data/ops/dequantizelinear.jsonc index 2dc04d11f2889..312f229d6f329 100644 --- a/js/web/test/data/ops/dequantizelinear.jsonc +++ b/js/web/test/data/ops/dequantizelinear.jsonc @@ -6,7 +6,7 @@ "attributes": [], "cases": [ { - "name": "T[1]", + "name": "uint8 per-tensor with zero point", "inputs": [ { "data": [1, 2, 3, 4], @@ -41,7 +41,7 @@ "attributes": [], "cases": [ { - "name": "T[2]", + "name": "int32 per-tensor no zero point", "inputs": [ { "data": [1, 2, 3, 4], @@ -64,6 +64,41 @@ } ] }, + { + "name": "dequantizelinear", + "operator": "DequantizeLinear", + "opset": { "domain": "", "version": 10 }, + "attributes": [], + "cases": [ + { + "name": "int32 per-tensor with zero point", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [4], + "type": "int32" + }, + { + "data": [0.1], + "dims": [1], + "type": "float32" + }, + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0.0, 0.1, 0.2, 0.3], + "dims": [4], + "type": "float32" + } + ] + } + ] + }, { "name": "dequantizelinear", "operator": "DequantizeLinear", @@ -77,7 +112,7 @@ ], "cases": [ { - "name": "T[3]", + "name": "uint8 2D per-axis scalar scale with zero point", "inputs": [ { "data": [1, 2, 3, 4], @@ -118,7 +153,7 @@ ], "cases": [ { - "name": "T[4]", + "name": "int32 2D per-axis scalar scale no zero point", "inputs": [ { "data": [1, 2, 3, 4], @@ -154,7 +189,48 @@ ], "cases": [ { - "name": "T[5]", + "name": "int32 2D per-axis scalar scale with zero point", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [2, 2], + "type": "int32" + }, + { + "data": [0.1], + "dims": [1], + "type": "float32" + }, + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0.0, 0.1, 0.2, 0.3], + "dims": [2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "dequantizelinear", + "operator": "DequantizeLinear", + "opset": { "domain": "", "version": 13 }, + "attributes": [ + { + "name": "axis", + "data": 1, + "type": "int" + } + ], + "cases": [ + { + "name": "uint8 3D per-axis uniform scale with zero point", "inputs": [ { "data": [1, 2, 3, 4, 5, 6, 7, 8], @@ -195,7 +271,7 @@ ], "cases": [ { - "name": "T[6]", + "name": "uint8 3D per-axis varying scale with zero point", "inputs": [ { "data": [1, 2, 3, 4, 5, 6, 7, 8], @@ -236,7 +312,7 @@ ], "cases": [ { - "name": "T[7]", + "name": "int32 3D per-axis scalar scale no zero point", "inputs": [ { "data": [1, 2, 3, 4, 5, 6, 7, 8], @@ -259,6 +335,47 @@ } ] }, + { + "name": "dequantizelinear", + "operator": "DequantizeLinear", + "opset": { "domain": "", "version": 13 }, + "attributes": [ + { + "name": "axis", + "data": 1, + "type": "int" + } + ], + "cases": [ + { + "name": "int32 3D per-axis scalar scale with zero point", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "int32" + }, + { + "data": [0.1], + "dims": [1], + "type": "float32" + }, + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], + "dims": [2, 2, 2], + "type": "float32" + } + ] + } + ] + }, { "name": "dequantizelinear", "operator": "DequantizeLinear", @@ -277,7 +394,7 @@ ], "cases": [ { - "name": "T[8]", + "name": "uint8 3D blocked with zero point", "inputs": [ { "data": [1, 2, 3, 4, 5, 6, 7, 8], @@ -323,7 +440,7 @@ ], "cases": [ { - "name": "T[9]", + "name": "int32 3D blocked no zero point", "inputs": [ { "data": [1, 2, 3, 4, 5, 6, 7, 8], @@ -346,6 +463,52 @@ } ] }, + { + "name": "dequantizelinear block dequantization", + "operator": "DequantizeLinear", + "opset": { "domain": "", "version": 21 }, + "attributes": [ + { + "name": "axis", + "data": 1, + "type": "int" + }, + { + "name": "block_size", + "data": 2, + "type": "int" + } + ], + "cases": [ + { + "name": "int32 3D blocked with zero point", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 2, 2], + "type": "int32" + }, + { + "data": [0.1, 0.2, 0.3, 0.4], + "dims": [2, 1, 2], + "type": "float32" + }, + { + "data": [0, 1, 0, 1], + "dims": [2, 1, 2], + "type": "int32" + } + ], + "outputs": [ + { + "data": [0.1, 0.2, 0.3, 0.6, 1.5, 2.0, 2.1, 2.8], + "dims": [2, 2, 2], + "type": "float32" + } + ] + } + ] + }, { "name": "dequantizelinear", "operator": "DequantizeLinear", @@ -359,7 +522,7 @@ ], "cases": [ { - "name": "T[3]", + "name": "uint8 2D per-axis scalar scale no zero point", "inputs": [ { "data": [1, 2, 3, 4], diff --git a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc index e7736c3f3afac..2cf0f11ce46f2 100644 --- a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc @@ -79,7 +79,7 @@ Status DequantizeLinearProgram::GenerateShaderCode(ShaderHelper& shader) const { if (packed_) { shader.MainFunctionBody() << "let zero_point_index = " << output.IndicesGet("output_indices", "uniforms.axis") << ";\n" - << "let zero_point_input = " << zero_point.GetByOffset("u32(zero_point_index / 4)") << ";\n" + << "let zero_point_input = " << zero_point.GetByOffset("zero_point_index / 4") << ";\n" << "let zero_point_vec = " << unpack << ";\n" << "let zero_point_value = zero_point_vec[zero_point_index % 4];\n"; } else { @@ -88,16 +88,17 @@ Status DequantizeLinearProgram::GenerateShaderCode(ShaderHelper& shader) const { << "let zero_point_value = " << zero_point.GetByOffset("zero_point_index") << ";\n"; } } else { - // BlockedQuantization. The zero-point input shape is same as the input shape except along axis. + // BlockedQuantization. The zero-point input shape is the same as the scale input shape. if (packed_) { shader.MainFunctionBody() - << "let zero_point_offset = " << scale.GetByIndices("scale_indices") << ";\n" - << "let zero_point_input = " << zero_point.GetByOffset("u32(zero_point_offset / 4)") << ";\n" + << "let zero_point_offset = " << scale.IndicesToOffset("scale_indices") << ";\n" + << "let zero_point_input = " << zero_point.GetByOffset("zero_point_offset / 4") << ";\n" << "let zero_point_vec = " << unpack << ";\n" << "let zero_point_value = zero_point_vec[zero_point_offset % 4];\n"; } else { shader.MainFunctionBody() - << "let zero_point_value = " << zero_point.GetByIndices("scale_indices") << ";\n"; + << "let zero_point_offset = " << scale.IndicesToOffset("scale_indices") << ";\n" + << "let zero_point_value = " << zero_point.GetByOffset("zero_point_offset") << ";\n"; } } } else { @@ -145,7 +146,9 @@ Status DequantizeLinear::ComputeInternal(ComputeContext& context) const { program .AddInputs({{x, ProgramTensorMetadataDependency::TypeAndRank, ProgramInput::Flatten, packed ? 4 : input_component}}) .AddInputs({{x_scale, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank, components}) + .AddOutput(use_components + ? ProgramOutput{output_tensor, ProgramTensorMetadataDependency::Rank, ProgramOutput::Flatten, components} + : ProgramOutput{output_tensor, ProgramTensorMetadataDependency::Rank, components}) .SetDispatchGroupSize((x_size / components + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .AddUniformVariables({{static_cast(axis)}}) .AddUniformVariables({{static_cast(block_size_)}})