Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions onnxruntime/core/providers/webgpu/program_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "core/providers/webgpu/program_manager.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_context.h"

namespace onnxruntime {
namespace webgpu {
Expand All @@ -22,7 +23,7 @@ ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeli
Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const {
ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")");

auto limit_per_dimension = limits_.maxComputeWorkgroupsPerDimension;
auto limit_per_dimension = webgpu_context_.DeviceLimits().maxComputeWorkgroupsPerDimension;
if (x > limit_per_dimension || y > limit_per_dimension || z > limit_per_dimension) {
double size = static_cast<double>(x) * static_cast<double>(y) * static_cast<double>(z);
double dispatch_avg = std::ceil(std::sqrt(size));
Expand All @@ -39,7 +40,7 @@ Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint
}

Status ProgramManager::CalculateSegmentsForInputsAndOutputs(ProgramBase& program) {
const uint64_t maxStorageBufferBindingSize = limits_.maxStorageBufferBindingSize;
const uint64_t maxStorageBufferBindingSize = webgpu_context_.DeviceLimits().maxStorageBufferBindingSize;

// Inputs
for (size_t i = 0; i < program.Inputs().size(); ++i) {
Expand Down Expand Up @@ -70,10 +71,11 @@ Status ProgramManager::Build(const ProgramBase& program,
uint32_t normalized_dispatch_z,
wgpu::ComputePipeline& compute_pipeline,
std::vector<int>& shape_uniform_ranks) const {
auto& device = webgpu_context_.Device();
ShaderHelper shader_helper{program,
program_metadata,
device_,
limits_,
device,
webgpu_context_.DeviceLimits(),
normalized_dispatch_x,
normalized_dispatch_y,
normalized_dispatch_z};
Expand Down Expand Up @@ -110,7 +112,7 @@ Status ProgramManager::Build(const ProgramBase& program,
wgpu::ShaderModuleDescriptor descriptor{};
descriptor.nextInChain = &wgsl_source;

auto shader_module = device_.CreateShaderModule(&descriptor);
auto shader_module = device.CreateShaderModule(&descriptor);

// TODO: a new cache hierarchy for constants.
//
Expand Down Expand Up @@ -186,9 +188,26 @@ Status ProgramManager::Build(const ProgramBase& program,
pipeline_descriptor.label = program.Name().c_str();
#endif

compute_pipeline = device_.CreateComputePipeline(&pipeline_descriptor);

return Status();
struct CreateComputePipelineContext {
wgpu::ComputePipeline& pipeline;
Status status;
} create_pipeline_context{compute_pipeline, {}};

ORT_RETURN_IF_ERROR(
webgpu_context_.Wait(
device.CreateComputePipelineAsync(
&pipeline_descriptor,
wgpu::CallbackMode::WaitAnyOnly,
[](wgpu::CreatePipelineAsyncStatus status, wgpu::ComputePipeline pipeline, wgpu::StringView message, CreateComputePipelineContext* context) {
if (status == wgpu::CreatePipelineAsyncStatus::Success) {
context->pipeline = std::move(pipeline);
} else {
context->status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create a WebGPU compute pipeline: ", std::string_view{message});
}
},
&create_pipeline_context)));

return create_pipeline_context.status;
}

const ProgramArtifact* ProgramManager::Get(const std::string& key) const {
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/webgpu/program_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
class Tensor;

namespace webgpu {
class WebGpuContext;

class ProgramArtifact {
public:
Expand All @@ -34,7 +35,7 @@

class ProgramManager {
public:
ProgramManager(const wgpu::Device& device, const wgpu::Limits& limits) : device_(device), limits_(limits) {}
ProgramManager(WebGpuContext& webgpu_context) : webgpu_context_(webgpu_context) {}

Check warning on line 38 in onnxruntime/core/providers/webgpu/program_manager.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Single-parameter constructors should be marked explicit. [runtime/explicit] [4] Raw Output: onnxruntime/core/providers/webgpu/program_manager.h:38: Single-parameter constructors should be marked explicit. [runtime/explicit] [4]

Status NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const;
Status CalculateSegmentsForInputsAndOutputs(ProgramBase& program);
Expand All @@ -54,8 +55,7 @@

private:
std::unordered_map<std::string, ProgramArtifact> programs_;
const wgpu::Device& device_;
const wgpu::Limits& limits_;
WebGpuContext& webgpu_context_;
};

} // namespace webgpu
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
BufferCacheMode::Disabled);

// create program manager
program_mgr_ = std::make_unique<ProgramManager>(Device(), DeviceLimits());
program_mgr_ = std::make_unique<ProgramManager>(*this);

// set query type
#if !defined(__wasm__)
Expand Down
Loading