diff --git a/onnxruntime/core/providers/webgpu/tensor/squeeze.cc b/onnxruntime/core/providers/webgpu/tensor/squeeze.cc index 136a1ba9776a0..63516d7e48d52 100644 --- a/onnxruntime/core/providers/webgpu/tensor/squeeze.cc +++ b/onnxruntime/core/providers/webgpu/tensor/squeeze.cc @@ -11,7 +11,43 @@ namespace webgpu { ONNX_OPERATOR_KERNEL_EX( Squeeze, kOnnxDomain, - 13, + 24, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 23, 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 13, 20, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", WebGpuSupportedNumberTypes()) diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 0f7607ac1dbfe..135782ad577c4 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -246,7 +246,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Squeeze); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, 23, Squeeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 24, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); @@ -529,7 +532,10 @@ std::unique_ptr RegisterKernels(bool enable_graph_capture = fals BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo,