diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
index 52c705abb1003..26b08159a3b84 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -468,7 +468,7 @@ bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const
          parameters.head_size_ == parameters.v_head_size_ && bias == nullptr &&
          parameters.sequence_length_ > 1 &&
-         context.Device().HasFeature(wgpu::FeatureName::Subgroups) &&
+         context.HasFeature(wgpu::FeatureName::Subgroups) &&
          present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 &&
          present_value->SizeInBytes() > 0 && parameters.head_size_ % 4 == 0;
 }
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
index 6d2370db853ee..ad8319aeff1ad 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
@@ -444,7 +444,7 @@ bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
                                    bool has_zero_points) {
   // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support.
   // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226
-  bool use_dp4a = context.Device().HasFeature(wgpu::FeatureName::Subgroups) &&
+  bool use_dp4a = context.HasFeature(wgpu::FeatureName::Subgroups) &&
                   context.AdapterInfo().backendType != wgpu::BackendType::Metal;
   return (accuracy_level == 4 && block_size % 32 == 0 && batch_count == 1 &&
           components_k == 4 && K % 128 == 0 && N % 16 == 0 &&
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
index be105a0fd4374..e8a15b8d47a56 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -572,7 +572,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
   // TODO: Support output_number > 1. Some cases are failed when output_number > 1.
   constexpr uint32_t output_number = 1;
   const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1;
-  const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups);
+  const bool has_subgroup = context.HasFeature(wgpu::FeatureName::Subgroups);
   const bool use_subgroup = has_subgroup && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32;
   MatMulNBitsProgram program{output_number, block_size, tile_m, static_cast<int32_t>(components_b), has_zero_points, use_subgroup};
   if (M > kMinMForTileOptimization && block_size == 32) {
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
index cb024d2a758a9..b1dce049214eb 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc
@@ -203,7 +203,7 @@ bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& cont
                                        uint32_t K,
                                        bool has_zero_points) {
 #if !defined(__wasm__)
-  const bool has_subgroup_matrix = context.Device().HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
+  const bool has_subgroup_matrix = context.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
 #else
   const bool has_subgroup_matrix = false;
 #endif
diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h
index 23fa10a0d5489..3117208c7be7d 100644
--- a/onnxruntime/core/providers/webgpu/compute_context.h
+++ b/onnxruntime/core/providers/webgpu/compute_context.h
@@ -37,8 +37,8 @@ class ComputeContext {
   inline const wgpu::Limits& DeviceLimits() const {
     return webgpu_context_.DeviceLimits();
   }
-  inline const wgpu::Device& Device() const {
-    return webgpu_context_.Device();
+  inline bool HasFeature(wgpu::FeatureName feature) const {
+    return webgpu_context_.DeviceHasFeature(feature);
   }

   //
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc
index 0471e08c4a215..955b54e873261 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -140,6 +140,12 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
   ORT_ENFORCE(Device().GetAdapterInfo(&adapter_info_));
   // cache device limits
   ORT_ENFORCE(Device().GetLimits(&device_limits_));
+  // cache device features
+  wgpu::SupportedFeatures supported_features;
+  Device().GetFeatures(&supported_features);
+  for (size_t i = 0; i < supported_features.featureCount; i++) {
+    device_features_.insert(supported_features.features[i]);
+  }

 #if !defined(__wasm__)
   supports_buffer_map_extended_usages_ = device_.HasFeature(wgpu::FeatureName::BufferMapExtendedUsages);
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h
index 11e388a22e03f..2f044400afee2 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -80,6 +80,7 @@ class WebGpuContext final {
   const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; }
   const wgpu::Limits& DeviceLimits() const { return device_limits_; }
+  bool DeviceHasFeature(wgpu::FeatureName feature) const { return device_features_.find(feature) != device_features_.end(); }

   const wgpu::CommandEncoder& GetCommandEncoder() {
     if (!current_command_encoder_) {
@@ -208,6 +209,7 @@ class WebGpuContext final {
   wgpu::AdapterInfo adapter_info_;
   wgpu::Limits device_limits_;
+  std::unordered_set<wgpu::FeatureName> device_features_;

   wgpu::CommandEncoder current_command_encoder_;
   wgpu::ComputePassEncoder current_compute_pass_encoder_;
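
Every call-site change above is the same mechanical substitution: kernels stop reaching through context.Device() and instead ask the context, which answers from a set of features cached once during WebGpuContext::Initialize. Below is a minimal standalone sketch of that caching pattern, assuming Dawn's C++ bindings (webgpu/webgpu_cpp.h); FeatureCache is a hypothetical stand-in for the role WebGpuContext plays in this diff, not code from the PR itself.

#include <unordered_set>
#include <webgpu/webgpu_cpp.h>

// Hypothetical illustration of the PR's caching pattern.
class FeatureCache {
 public:
  // Enumerate the device's supported features exactly once, up front.
  explicit FeatureCache(const wgpu::Device& device) {
    wgpu::SupportedFeatures supported;
    device.GetFeatures(&supported);
    for (size_t i = 0; i < supported.featureCount; i++) {
      features_.insert(supported.features[i]);
    }
  }

  // Subsequent queries are O(1) hash-set lookups; no further calls
  // into the WebGPU device are made.
  bool Has(wgpu::FeatureName feature) const {
    return features_.find(feature) != features_.end();
  }

 private:
  // std::hash is specialized for enum types since C++14, so the
  // wgpu::FeatureName enum needs no custom hasher.
  std::unordered_set<wgpu::FeatureName> features_;
};

With the cache in place, per-dispatch checks such as context.HasFeature(wgpu::FeatureName::Subgroups) in the matmul and flash-attention kernels answer from the cached set rather than going through the wgpu::Device each time, which is also why the diff can remove the ComputeContext::Device() accessor entirely.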