diff --git a/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc b/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc index e822f8764b63f..0f2e6d725007f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc @@ -21,16 +21,16 @@ ONNX_OPERATOR_KERNEL_EX( BiasAdd); Status BiasAddProgram::GenerateShaderCode(ShaderHelper& shader) const { - const ShaderVariableHelper& input = shader.AddInput("input"); - const ShaderVariableHelper& bias = shader.AddInput("bias"); - const ShaderVariableHelper& residual = shader.AddInput("residual"); - const ShaderVariableHelper& output = shader.AddOutput("output"); + const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform); + const ShaderVariableHelper& bias = shader.AddInput("bias", ShaderUsage::UseUniform); + const ShaderVariableHelper& residual = shader.AddInput("residual", ShaderUsage::UseUniform); + const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform); shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") - << "let value = " << input.GetByOffset("global_idx") + << " let value = " << input.GetByOffset("global_idx") << " + " << bias.GetByOffset("global_idx % uniforms.channels") << " + " << residual.GetByOffset("global_idx") << ";\n" - << output.SetByOffset("global_idx", "value"); + << " " + output.SetByOffset("global_idx", "value"); return Status::OK(); } @@ -47,23 +47,26 @@ Status BiasAdd::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) co } int64_t channels = input_shape[2]; - int64_t components = GetMaxComponents(channels); - channels /= components; - TensorShape bias_shape = bias->Shape(); if (bias_shape.NumDimensions() != 1 || bias_shape[0] != channels) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BiasAdd bias should have 1 dimension with size equal to the number of channels."); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BiasAdd bias should have 1 dimension with size equal to the number of channels."); } + int components = GetMaxComponents(channels); + channels /= components; + auto* output = context.Output(0, input_shape); int64_t output_size = output->Shape().Size() / components; BiasAddProgram program{}; - program.AddInputs({{input}, {bias}, {residual}}) - .AddOutput({output}) + program + .AddInputs({{input, ProgramTensorMetadataDependency::None, components}, + {bias, ProgramTensorMetadataDependency::None, components}, + {residual, ProgramTensorMetadataDependency::None, components}}) + .AddOutput({output, ProgramTensorMetadataDependency::None, components}) .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({{static_cast(output_size)}, - {static_cast(channels)}}); + .AddUniformVariables({{static_cast(output_size)}, {static_cast(channels)}}); return context.RunProgram(program); }