diff --git a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc index a0213f63494d3..9548386ded06c 100644 --- a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc @@ -91,6 +91,14 @@ REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSumExp, 11, 12); REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSumExp, 13, 17); REGISTER_REDUCE_KERNEL(ReduceLogSumExp, 18); +REGISTER_REDUCE_VERSIONED_KERNEL(ArgMax, 1, 10); +REGISTER_REDUCE_VERSIONED_KERNEL(ArgMax, 11, 12); +REGISTER_REDUCE_KERNEL(ArgMax, 13); + +REGISTER_REDUCE_VERSIONED_KERNEL(ArgMin, 1, 10); +REGISTER_REDUCE_VERSIONED_KERNEL(ArgMin, 11, 12); +REGISTER_REDUCE_KERNEL(ArgMin, 13); + Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); if (is_input_empty_) { @@ -114,6 +122,9 @@ Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const { std::stringstream ss; std::string index = "i" + std::to_string(i); ss << "for (var " << index << " : u32 = 0; " << index << " < " << input.IndicesGet("uniforms.input_shape", i) << "; " << index << "++) {\n"; + if (loop_body.find("last_index") != std::string::npos) { + ss << "let last_index = " + index + ";\n"; + } ss << input.IndicesSet("input_indices", i, index) << ";\n"; ss << loop_body << "\n"; ss << "}\n"; @@ -337,5 +348,25 @@ ReduceOpSpecificCode ReduceLogSumExp::GetOpSpecificCode(const Tensor* input_tens return code; } +ReduceOpSpecificCode ArgMin::GetOpSpecificCode(const Tensor* input_tensor) const { + ORT_UNUSED_PARAMETER(input_tensor); + std::string op = (select_last_index_) ? "<=" : "<"; + std::string loop_header = "var best_element = first_element; var best_index = u32(0);"; + std::string loop_body = "if (current_element " + op + " best_element) { best_element = current_element; best_index = last_index; };"; + std::string loop_footer = "let output_value = output_value_t(best_index);"; + ReduceOpSpecificCode code({loop_header, loop_body, loop_footer}); + return code; +} + +ReduceOpSpecificCode ArgMax::GetOpSpecificCode(const Tensor* input_tensor) const { + ORT_UNUSED_PARAMETER(input_tensor); + std::string op = (select_last_index_) ? ">=" : ">"; + std::string loop_header = "var best_element = first_element; var best_index = u32(0);"; + std::string loop_body = "if (current_element " + op + " best_element) { best_element = current_element; best_index = last_index; };"; + std::string loop_footer = "let output_value = output_value_t(best_index);"; + ReduceOpSpecificCode code({loop_header, loop_body, loop_footer}); + return code; +} + } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h index 291d931f41c05..70ae6d3c71eb9 100644 --- a/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h +++ b/onnxruntime/core/providers/webgpu/reduction/reduction_ops.h @@ -119,5 +119,17 @@ class ReduceLogSumExp final : public ReduceKernel { ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override; }; +class ArgMin final : public ReduceKernel { + public: + ArgMin(const OpKernelInfo& info) : ReduceKernel(info, "ArgMin", true) {} + ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override; +}; + +class ArgMax final : public ReduceKernel { + public: + ArgMax(const OpKernelInfo& info) : ReduceKernel(info, "ArgMax", true) {} + ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override; +}; + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index dfb2e4b6ce665..6f81bead5e5b1 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -297,12 +297,12 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, MatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, MatMul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ArgMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ArgMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, ArgMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ArgMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ArgMin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, ArgMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Softmax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Softmax); @@ -624,13 +624,13 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo,