diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index 697428e1ce140..c2a8896b84a7e 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/webgpu/nn/conv.h" #include "core/providers/webgpu/nn/conv2d_mm.h" +#include "core/providers/webgpu/nn/conv3d_naive.h" #include "core/providers/webgpu/nn/im2col_matmul.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -80,8 +81,42 @@ Status Conv::ComputeInternal(ComputeContext& context std::transform(local_dilations.begin(), local_dilations.end(), std::back_inserter(dilations), transform_dim); auto rank = input_shape.NumDimensions(); const InlinedVector perm = {2, 3, 1, 0}; - if (rank > 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only Conv1d and Conv2d are supported."); + if (rank > 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only Conv1d, Conv2d, and Conv3d are supported."); + } else if (rank == 5) { + // Conv3D - use naive per-element shader (matching JS implementation) + if (conv_attrs_.group != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Conv3D does not support grouped convolution (group=", conv_attrs_.group, ")."); + } + const auto output_size = static_cast(output_shape.Size()); + const auto kernel_depth = static_cast(kernel_shape[2]); + const auto kernel_height = static_cast(kernel_shape[3]); + const auto kernel_width = static_cast(kernel_shape[4]); + // pads: head padding values for each spatial dim (front, top, left) + std::vector pads_3d{pads[0], pads[1], pads[2]}; + // Extract spatial dims and channels for explicit uniforms + const auto x_depth = static_cast(input_shape[is_channels_last ? 1 : 2]); + const auto x_height = static_cast(input_shape[is_channels_last ? 2 : 3]); + const auto x_width = static_cast(input_shape[is_channels_last ? 3 : 4]); + const auto x_channels = static_cast(input_shape[is_channels_last ? 4 : 1]); + Conv3DNaiveProgram program(activation_, has_bias, is_channels_last); + program.CacheHint(activation_.ToString(), std::to_string(is_channels_last)) + .AddInput({input, ProgramTensorMetadataDependency::TypeAndRank, input_shape, 1}) + .AddInput({kernel, ProgramTensorMetadataDependency::TypeAndRank, kernel_shape, 1}) + .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape, 1}) + .AddUniformVariables({{output_size}, + {std::vector{kernel_depth, kernel_height, kernel_width}}, + {pads_3d}, + {strides}, + {dilations}, + {std::vector{x_depth, x_height, x_width}}, + {x_channels}}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, bias->Shape(), 1}); + } + return context.RunProgram(program); } else if (rank == 4) { // Conv2D } else if (rank == 3) { diff --git a/onnxruntime/core/providers/webgpu/nn/conv3d_naive.cc b/onnxruntime/core/providers/webgpu/nn/conv3d_naive.cc new file mode 100644 index 0000000000000..76895e684eeab --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv3d_naive.cc @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/providers/webgpu/nn/conv3d_naive.h" +#include "core/providers/webgpu/nn/fuse_utils.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/shader_variable.h" + +namespace onnxruntime { +namespace webgpu { + +Status Conv3DNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | + ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); + const auto& w = shader.AddInput("w", ShaderUsage::UseUniform | + ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | + ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); + + std::string apply_activation = GetActivationSnippet(activation_, "x_value_t", "x_element_t"); + + // Helper functions to access x and w by 5D indices + shader.AdditionalImplementation() + << "fn getX(d0 : u32, d1 : u32, d2 : u32, d3 : u32, d4 : u32) -> x_value_t {\n" + << " let aIndices = x_indices_t(d0, d1, d2, d3, d4);\n" + << " return " << x.GetByIndices("aIndices") << ";\n" + << "}\n" + << "fn getW(d0 : u32, d1 : u32, d2 : u32, d3 : u32, d4 : u32) -> x_value_t {\n" + << " let aIndices = w_indices_t(d0, d1, d2, d3, d4);\n" + << " return " << w.GetByIndices("aIndices") << ";\n" + << "}\n"; + + // Spatial dimensions and channels are passed as explicit uniforms + // to avoid rank-5 shape packing issues (array,2> vs vec4). + shader.MainFunctionBody() + << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << "let batch = output_indices[0];\n" + << "let d2 = " << output.IndicesGet("output_indices", is_channels_last_ ? "4" : "1") << ";\n" + << "let xFRCCorner = vec3(" << output.IndicesGet("output_indices", is_channels_last_ ? "1" : "2") << ", " + << output.IndicesGet("output_indices", is_channels_last_ ? "2" : "3") << ", " + << output.IndicesGet("output_indices", is_channels_last_ ? "3" : "4") << ") * uniforms.strides - uniforms.pads;\n" + << "let xFCorner = xFRCCorner.x;\n" + << "let xRCorner = xFRCCorner.y;\n" + << "let xCCorner = xFRCCorner.z;\n" + << "let xDepth = uniforms.x_spatial[0];\n" + << "let xHeight = uniforms.x_spatial[1];\n" + << "let xWidth = uniforms.x_spatial[2];\n" + << "let xChannels = uniforms.x_channels;\n" + << "let inputChannelsNearestVec4 = (xChannels / 4u) * 4u;\n" + << "let inputChannelsVec4Remainder = xChannels % 4u;\n" + << "\n" + << "var value = x_value_t(0);\n" + << "for (var wF = 0u; wF < uniforms.filter_dims[0]; wF++) {\n" + << " let xF = xFCorner + wF * uniforms.dilations[0];\n" + << " if (xF >= xDepth) {\n" + << " continue;\n" + << " }\n" + << " for (var wR = 0u; wR < uniforms.filter_dims[1]; wR++) {\n" + << " let xR = xRCorner + wR * uniforms.dilations[1];\n" + << " if (xR >= xHeight) {\n" + << " continue;\n" + << " }\n" + << " for (var wC = 0u; wC < uniforms.filter_dims[2]; wC++) {\n" + << " let xC = xCCorner + wC * uniforms.dilations[2];\n" + << " if (xC >= xWidth) {\n" + << " continue;\n" + << " }\n" + << " for (var d1 = 0u; d1 < inputChannelsNearestVec4; d1 += 4u) {\n"; + + // vec4 dot product accumulation over input channels + if (is_channels_last_) { + shader.MainFunctionBody() + << " let xValues = vec4(\n" + << " getX(batch, xF, xR, xC, d1),\n" + << " getX(batch, xF, xR, xC, d1 + 1u),\n" + << " getX(batch, xF, xR, xC, d1 + 2u),\n" + << " getX(batch, xF, xR, xC, d1 + 3u));\n"; + } else { + shader.MainFunctionBody() + << " let xValues = vec4(\n" + << " getX(batch, d1, xF, xR, xC),\n" + << " getX(batch, d1 + 1u, xF, xR, xC),\n" + << " getX(batch, d1 + 2u, xF, xR, xC),\n" + << " getX(batch, d1 + 3u, xF, xR, xC));\n"; + } + shader.MainFunctionBody() + << " let wValues = vec4(\n" + << " getW(d2, d1, wF, wR, wC),\n" + << " getW(d2, d1 + 1u, wF, wR, wC),\n" + << " getW(d2, d1 + 2u, wF, wR, wC),\n" + << " getW(d2, d1 + 3u, wF, wR, wC));\n" + << " value += x_value_t(dot(xValues, wValues));\n" + << " }\n"; + + // Handle remainder channels (1, 2, or 3) + shader.MainFunctionBody() + << " if (inputChannelsVec4Remainder == 1u) {\n"; + if (is_channels_last_) { + shader.MainFunctionBody() + << " value += getX(batch, xF, xR, xC, inputChannelsNearestVec4)\n" + << " * getW(d2, inputChannelsNearestVec4, wF, wR, wC);\n"; + } else { + shader.MainFunctionBody() + << " value += getX(batch, inputChannelsNearestVec4, xF, xR, xC)\n" + << " * getW(d2, inputChannelsNearestVec4, wF, wR, wC);\n"; + } + shader.MainFunctionBody() + << " } else if (inputChannelsVec4Remainder == 2u) {\n"; + if (is_channels_last_) { + shader.MainFunctionBody() + << " let xValues = vec2(\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4),\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4 + 1u));\n"; + } else { + shader.MainFunctionBody() + << " let xValues = vec2(\n" + << " getX(batch, inputChannelsNearestVec4, xF, xR, xC),\n" + << " getX(batch, inputChannelsNearestVec4 + 1u, xF, xR, xC));\n"; + } + shader.MainFunctionBody() + << " let wValues = vec2(\n" + << " getW(d2, inputChannelsNearestVec4, wF, wR, wC),\n" + << " getW(d2, inputChannelsNearestVec4 + 1u, wF, wR, wC));\n" + << " value += x_value_t(dot(xValues, wValues));\n" + << " } else if (inputChannelsVec4Remainder == 3u) {\n"; + if (is_channels_last_) { + shader.MainFunctionBody() + << " let xValues = vec3(\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4),\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4 + 1u),\n" + << " getX(batch, xF, xR, xC, inputChannelsNearestVec4 + 2u));\n"; + } else { + shader.MainFunctionBody() + << " let xValues = vec3(\n" + << " getX(batch, inputChannelsNearestVec4, xF, xR, xC),\n" + << " getX(batch, inputChannelsNearestVec4 + 1u, xF, xR, xC),\n" + << " getX(batch, inputChannelsNearestVec4 + 2u, xF, xR, xC));\n"; + } + shader.MainFunctionBody() + << " let wValues = vec3(\n" + << " getW(d2, inputChannelsNearestVec4, wF, wR, wC),\n" + << " getW(d2, inputChannelsNearestVec4 + 1u, wF, wR, wC),\n" + << " getW(d2, inputChannelsNearestVec4 + 2u, wF, wR, wC));\n" + << " value += x_value_t(dot(xValues, wValues));\n" + << " }\n" + << " }\n" + << " }\n" + << "}\n"; + + // Apply bias + if (has_bias_) { + const auto& b = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.MainFunctionBody() << "value = value + " << b.GetByIndices("d2") << ";\n"; + } + + // Apply activation + shader.MainFunctionBody() << apply_activation << "\n"; + + // Write output + shader.MainFunctionBody() << output.SetByOffset("global_idx", "value"); + + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv3d_naive.h b/onnxruntime/core/providers/webgpu/nn/conv3d_naive.h new file mode 100644 index 0000000000000..25ae449a7d02c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv3d_naive.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/nn/fuse_utils.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class Conv3DNaiveProgram final : public Program { + public: + Conv3DNaiveProgram(const Activation& activation, bool has_bias, bool is_channels_last) + : Program("Conv3DNaive"), activation_(activation), has_bias_(has_bias), is_channels_last_(is_channels_last) { + } + Status GenerateShaderCode(ShaderHelper& shader) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"output_size", ProgramUniformVariableDataType::Uint32}, + {"filter_dims", ProgramUniformVariableDataType::Uint32}, + {"pads", ProgramUniformVariableDataType::Uint32}, + {"strides", ProgramUniformVariableDataType::Uint32}, + {"dilations", ProgramUniformVariableDataType::Uint32}, + {"x_spatial", ProgramUniformVariableDataType::Uint32}, + {"x_channels", ProgramUniformVariableDataType::Uint32}); + + private: + const Activation& activation_; + bool has_bias_; + bool is_channels_last_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 6d6fedb3c9812..843d925ed6638 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -686,6 +686,8 @@ TEST(ConvFp16Test, Conv2D_AutoPad2) { TestConvFp16Op(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); } +// TODO: Enable Conv3D fp16 tests for WebGPU when the test infrastructure supports +// conditionally skipping based on device capabilities (e.g., wgpu::FeatureName::ShaderF16). TEST(ConvFp16Test, Conv3D_1) { ConvOpAndTestAttributes attrs = { "", // auto_pad diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index a1e5837a078a4..cda666b8e35f3 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -725,7 +725,7 @@ TEST(ConvTest, Conv3D_1) { vector{1, 1, 1}, // kernel_shape vector{0, 0, 0, 0, 0, 0}, // pads vector{1, 1, 1}, // strides - {kWebGpuExecutionProvider} // excluded EPs + {} // excluded EPs }; vector X = {-0.43337246775627136f, -0.48385289311408997f, -0.30954962968826294f, @@ -762,7 +762,7 @@ TEST(ConvTest, Conv3D_2) { vector{1, 1, 1}, // kernel_shape vector{2, 2, 2, 2, 2, 2}, // pads vector{2, 2, 2}, // strides - {kWebGpuExecutionProvider} // excluded EPs + {} // excluded EPs }; vector X = {0.010772407054901123f, -0.43806642293930054f, 0.455391526222229f, -0.28657248616218567f, @@ -805,7 +805,7 @@ TEST(ConvTest, Conv3D_Bias) { vector{2, 2, 2}, // kernel_shape vector{2, 2, 2, 2, 2, 2}, // pads vector{2, 2, 2}, // strides - {kWebGpuExecutionProvider} // excluded EPs + {} // excluded EPs }; vector X = {0.46796226501464844f, -0.4613912105560303f, 0.33512794971466064f, -0.4010460674762726f,