2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -468,7 +468,7 @@ bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const
         parameters.head_size_ == parameters.v_head_size_ &&
         bias == nullptr &&
         parameters.sequence_length_ > 1 &&
-        context.Device().HasFeature(wgpu::FeatureName::Subgroups) &&
+        context.HasFeature(wgpu::FeatureName::Subgroups) &&
         present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 &&
         present_value->SizeInBytes() > 0 && parameters.head_size_ % 4 == 0;
 }
@@ -444,7 +444,7 @@ bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
                                    bool has_zero_points) {
   // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support.
   // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226
-  bool use_dp4a = context.Device().HasFeature(wgpu::FeatureName::Subgroups) &&
+  bool use_dp4a = context.HasFeature(wgpu::FeatureName::Subgroups) &&
                   context.AdapterInfo().backendType != wgpu::BackendType::Metal;
   return (accuracy_level == 4 && block_size % 32 == 0 &&
           batch_count == 1 && components_k == 4 && K % 128 == 0 && N % 16 == 0 &&
@@ -572,7 +572,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
   // TODO: Support output_number > 1. Some cases are failed when output_number > 1.
   constexpr uint32_t output_number = 1;
   const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1;
-  const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups);
+  const bool has_subgroup = context.HasFeature(wgpu::FeatureName::Subgroups);
   const bool use_subgroup = has_subgroup && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32;
   MatMulNBitsProgram program{output_number, block_size, tile_m, static_cast<int>(components_b), has_zero_points, use_subgroup};
   if (M > kMinMForTileOptimization && block_size == 32) {
@@ -203,7 +203,7 @@ bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& cont
                                        uint32_t K,
                                        bool has_zero_points) {
 #if !defined(__wasm__)
-  const bool has_subgroup_matrix = context.Device().HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
+  const bool has_subgroup_matrix = context.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
 #else
   const bool has_subgroup_matrix = false;
 #endif
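All four call sites above make the same mechanical substitution: the feature check stays where it was, but it no longer reaches through a raw wgpu::Device handle. As an illustration of the resulting shape (the function name CanUseSubgroupPath is hypothetical, not from this PR; context.HasFeature and context.AdapterInfo are the real accessors shown in the hunks):

// Illustrative sketch only: a feature-gated guard after this change.
// CanUseSubgroupPath is a made-up name; the accessors are the PR's.
bool CanUseSubgroupPath(onnxruntime::webgpu::ComputeContext& context) {
  return context.HasFeature(wgpu::FeatureName::Subgroups) &&
         context.AdapterInfo().backendType != wgpu::BackendType::Metal;
}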
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webgpu/compute_context.h
@@ -37,8 +37,8 @@ class ComputeContext {
   inline const wgpu::Limits& DeviceLimits() const {
     return webgpu_context_.DeviceLimits();
   }
-  inline const wgpu::Device& Device() const {
-    return webgpu_context_.Device();
+  inline bool HasFeature(wgpu::FeatureName feature) const {
+    return webgpu_context_.DeviceHasFeature(feature);
   }

   //
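For kernel authors the surface narrows: ComputeContext no longer hands out the wgpu::Device for feature probes. A minimal before/after sketch, assuming a ComputeContext& named context as in the hunks above:

// Before: every probe crossed into the underlying WebGPU device object.
bool old_way = context.Device().HasFeature(wgpu::FeatureName::Subgroups);
// After: the probe is answered from features cached at context initialization.
bool new_way = context.HasFeature(wgpu::FeatureName::Subgroups);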
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -140,6 +140,12 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
   ORT_ENFORCE(Device().GetAdapterInfo(&adapter_info_));
   // cache device limits
   ORT_ENFORCE(Device().GetLimits(&device_limits_));
+  // cache device features
+  wgpu::SupportedFeatures supported_features;
+  Device().GetFeatures(&supported_features);
+  for (size_t i = 0; i < supported_features.featureCount; i++) {
+    device_features_.insert(supported_features.features[i]);
+  }

 #if !defined(__wasm__)
   supports_buffer_map_extended_usages_ = device_.HasFeature(wgpu::FeatureName::BufferMapExtendedUsages);
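For reference, here is the same enumerate-once pattern as a self-contained sketch. The free function CacheFeatures and the include path are assumptions for illustration; SupportedFeatures::featureCount and ::features are used exactly as in the hunk above:

#include <unordered_set>

#include <webgpu/webgpu_cpp.h>  // assumed include path; varies by build setup

// Enumerate the device's supported features once and keep them in a hash set,
// so later queries are an in-process lookup instead of a call back into the
// WebGPU implementation.
std::unordered_set<wgpu::FeatureName> CacheFeatures(const wgpu::Device& device) {
  wgpu::SupportedFeatures supported;
  device.GetFeatures(&supported);  // fills featureCount and the features array
  std::unordered_set<wgpu::FeatureName> cached;
  cached.reserve(supported.featureCount);
  for (size_t i = 0; i < supported.featureCount; ++i) {
    cached.insert(supported.features[i]);
  }
  return cached;
}

std::unordered_set works here because wgpu::FeatureName is a scoped enum and std::hash has covered enumeration types since C++14.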
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -80,6 +80,7 @@

   const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; }
   const wgpu::Limits& DeviceLimits() const { return device_limits_; }
+  bool DeviceHasFeature(wgpu::FeatureName feature) const { return device_features_.find(feature) != device_features_.end(); }

   const wgpu::CommandEncoder& GetCommandEncoder() {
     if (!current_command_encoder_) {
@@ -208,6 +209,7 @@

   wgpu::AdapterInfo adapter_info_;
   wgpu::Limits device_limits_;
+  std::unordered_set<wgpu::FeatureName> device_features_;
Check warning on line 212 in onnxruntime/core/providers/webgpu/webgpu_context.h (GitHub Actions / Optional Lint C++): [cpplint] reported by reviewdog 🐶: Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4]

   wgpu::CommandEncoder current_command_encoder_;
   wgpu::ComputePassEncoder current_compute_pass_encoder_;
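The fix the linter asks for is the one-line include it names, added near the top of webgpu_context.h (the trailing comment is mine):

#include <unordered_set>  // for device_features_ (std::unordered_set<wgpu::FeatureName>)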