diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 584cea7698..f7f3dbd445 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -267,25 +267,41 @@ struct webgpu_command { #endif }; +struct webgpu_capabilities { + wgpu::Limits limits; + bool supports_subgroup_matrix = false; + wgpu::SubgroupMatrixConfig subgroup_matrix_config{}; + uint32_t subgroup_size = 0; + uint32_t max_subgroup_size = 0; +}; + +// Stores global webgpu members +struct webgpu_global_context { + wgpu::Instance instance; + wgpu::Adapter adapter; + wgpu::Device device; + // Shared buffer to move data from device to host + wgpu::Buffer get_tensor_staging_buf; + // Mutex to lock get_tensor_staging_buf + std::recursive_mutex mutex; + webgpu_capabilities capabilities; +}; + // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { - wgpu::Instance instance; - wgpu::Adapter adapter; - wgpu::Device device; - wgpu::Queue queue; - wgpu::Limits limits; + // Points to global instances owned by ggml_backend_webgpu_reg_context + webgpu_global_context * global_ctx; - uint32_t max_subgroup_size; + wgpu::Queue queue; + webgpu_buf_pool param_buf_pool; bool supports_subgroup_matrix = false; uint32_t sg_mat_m; uint32_t sg_mat_n; uint32_t sg_mat_k; - std::recursive_mutex mutex; - std::atomic_uint inflight_threads = 0; + std::atomic_uint inflight_threads = 0; - webgpu_buf_pool param_buf_pool; webgpu_buf_pool set_rows_error_buf_pool; pre_wgsl::Preprocessor p; @@ -326,9 +342,6 @@ struct webgpu_context_struct { size_t memset_bytes_per_thread; - // Staging buffer for reading data from the GPU - wgpu::Buffer get_tensor_staging_buf; - #ifdef GGML_WEBGPU_DEBUG wgpu::Buffer debug_host_buf; wgpu::Buffer debug_dev_buf; @@ -351,34 +364,40 @@ struct webgpu_context_struct { typedef std::shared_ptr webgpu_context; +// Metadata required for the ggml backend registration/discovery interface struct ggml_backend_webgpu_reg_context { - webgpu_context webgpu_ctx; - size_t device_count; - const char * name; + // Since the Instance is a global entrypoint into the WebGPU api, it lives here + webgpu_global_context wgpu_global_ctx; + size_t device_count; + const char * name; }; +// Per-device struct for the global logical device interface struct ggml_backend_webgpu_device_context { - webgpu_context webgpu_ctx; - std::string device_name; - std::string device_desc; + ggml_backend_webgpu_reg_context * reg_context; + std::string device_name; + std::string device_desc; }; +// Per-thread data required to actually run webgpu operations in a backend instance struct ggml_backend_webgpu_context { - webgpu_context webgpu_ctx; + std::once_flag init_once; std::string name; }; +// Per-thread data related to buffers struct ggml_backend_webgpu_buffer_context { - webgpu_context webgpu_ctx; - wgpu::Buffer buffer; - std::string label; + wgpu::Buffer buffer; + std::string label; - ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) : - webgpu_ctx(std::move(ctx)), + ggml_backend_webgpu_buffer_context(wgpu::Buffer buf, std::string lbl) : buffer(std::move(buf)), label(std::move(lbl)) {} }; +// Forward declaration +static webgpu_context get_per_thread_webgpu_context(ggml_backend_dev_t dev); + /* WebGPU object initializations */ // Process a WGSL shader string, replacing tokens of the form {{KEY}} with @@ -453,12 +472,13 @@ static void ggml_backend_webgpu_wait(webgpu_context & ct uint32_t inflight_threads = ctx->inflight_threads; uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); while (futures.size() >= inflight_max && futures.size() > 0) { - ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); + ctx->global_ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); futures.erase(futures.begin()); } size_t i = 0; while (i < futures.size()) { - auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms); + auto waitStatus = + ctx->global_ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms); switch (waitStatus) { case wgpu::WaitStatus::Success: futures.erase(futures.begin() + i); @@ -481,14 +501,14 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, wgpu::MapMode mode, size_t offset, size_t size) { - ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", - message.data); - } - }), - UINT64_MAX); + ctx->global_ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, + [](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", + message.data); + } + }), + UINT64_MAX); } #ifdef GGML_WEBGPU_DEBUG @@ -496,7 +516,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and // debug statements in the shader, and then call this function after encoding the commands and submitting them. static void ggml_backend_webgpu_debug(webgpu_context & ctx) { - wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); ctx->queue.Submit(1, &commands); @@ -617,12 +637,12 @@ static webgpu_command ggml_backend_webgpu_build_multi( bind_group_desc.entryCount = entries.size(); bind_group_desc.entries = entries.data(); bind_group_desc.label = pipelines[i].name.c_str(); - bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc)); + bind_groups.push_back(ctx->global_ctx->device.CreateBindGroup(&bind_group_desc)); params_bufs_list.push_back(params_bufs); } - wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); for (const auto & params_bufs : params_bufs_list) { encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); } @@ -772,12 +792,12 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { size_t offset = ggml_webgpu_tensor_offset(t); - return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); + return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); } static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { size_t offset = ggml_webgpu_tensor_offset(t); - return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1); + return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); } static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { @@ -825,21 +845,22 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g const bool circular = ggml_get_op_params_i32(dst, 8) != 0; ggml_webgpu_pad_pipeline_key pipeline_key = { .circular = circular }; - ggml_webgpu_pad_shader_lib_context shader_lib_ctx = { .key = pipeline_key, - .max_wg_size = - ctx->limits.maxComputeInvocationsPerWorkgroup }; + ggml_webgpu_pad_shader_lib_context shader_lib_ctx = { + .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + }; webgpu_pipeline pipeline; { // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); auto it = ctx->pad_pipelines.find(pipeline_key); if (it != ctx->pad_pipelines.end()) { pipeline = it->second; } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); pipeline.context = processed.decisions; ctx->pad_pipelines.emplace(pipeline_key, pipeline); } @@ -907,21 +928,22 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, .vec4 = src->ne[0] % 4 == 0, .i64_idx = idx->type == GGML_TYPE_I64 }; - ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = { .key = key, - .max_wg_size = - ctx->limits.maxComputeInvocationsPerWorkgroup }; + ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = { + .key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + }; webgpu_pipeline pipeline; // TODO: remove guard once pipeline caches are per-thread { - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); auto it = ctx->set_rows_pipelines.find(key); if (it != ctx->set_rows_pipelines.end()) { pipeline = it->second; } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); pipeline.context = processed.decisions; ctx->set_rows_pipelines.emplace(key, pipeline); } @@ -1098,8 +1120,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG); uint32_t total_wg = output_groups * batches; - wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; - wg_y = CEIL_DIV(total_wg, ctx->limits.maxComputeWorkgroupsPerDimension); + wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); } else { pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; uint32_t wg_m; @@ -1223,25 +1245,27 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .uses_logit_softcap = logit_softcap != 0.0f, }; - webgpu_pipeline pipeline; + webgpu_pipeline pipeline; // TODO: remove guard once pipeline caches are per-thread { - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); auto it = ctx->flash_attn_pipelines.find(key); if (it != ctx->flash_attn_pipelines.end()) { - pipeline = it->second; + pipeline = it->second; } else { - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .key = key, - .sg_mat_m = ctx->sg_mat_m, - .sg_mat_n = ctx->sg_mat_n, - .sg_mat_k = ctx->sg_mat_k, - .wg_mem_limit_bytes = - ctx->limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->max_subgroup_size }; + ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { + .key = key, + .sg_mat_m = ctx->sg_mat_m, + .sg_mat_n = ctx->sg_mat_n, + .sg_mat_k = ctx->sg_mat_k, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size + }; ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); pipeline.context = processed.decisions; ctx->flash_attn_pipelines.emplace(key, pipeline); } @@ -1250,7 +1274,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_flash_attn_shader_decisions decisions = *static_cast(pipeline.context); - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); @@ -1264,21 +1287,22 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s ggml_webgpu_unary_pipeline_key pipeline_key = { .type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace }; - ggml_webgpu_unary_shader_lib_context shader_lib_ctx = { .key = pipeline_key, - .max_wg_size = - ctx->limits.maxComputeInvocationsPerWorkgroup }; + ggml_webgpu_unary_shader_lib_context shader_lib_ctx = { + .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + }; webgpu_pipeline pipeline; { // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); auto it = ctx->unary_pipelines.find(pipeline_key); if (it != ctx->unary_pipelines.end()) { pipeline = it->second; } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); pipeline.context = processed.decisions; ctx->unary_pipelines.emplace(pipeline_key, pipeline); } @@ -1696,20 +1720,21 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { .vec4 = src->ne[0] % 4 == 0, - .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; webgpu_pipeline pipeline; { // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4); if (it != ctx->argmax_pipelines.end()) { pipeline = it->second; } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax"); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline); } } @@ -1722,13 +1747,13 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr // ascending order is 0, descending order is 1 const int32_t order = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0); - ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = { .max_wg_size = - ctx->limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = - ctx->limits.maxComputeWorkgroupStorageSize, - .order = order }; + ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = { + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + .order = order + }; - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); webgpu_pipeline argsort_pipeline; auto it = ctx->argsort_pipelines.find(order); if (it != ctx->argsort_pipelines.end()) { @@ -1736,7 +1761,8 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx); - argsort_pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + argsort_pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); argsort_pipeline.context = processed.decisions; ctx->argsort_pipelines.emplace(order, argsort_pipeline); } @@ -1751,7 +1777,7 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx); argsort_merge_pipeline = - ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); argsort_merge_pipeline.context = processed.decisions; ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline); } @@ -1780,9 +1806,10 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr const bool start_in_tmp = (merge_passes % 2) == 1; - const size_t dst_offset = ggml_webgpu_tensor_offset(dst); - const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t); - const size_t tmp_offset = ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->limits.minStorageBufferOffsetAlignment); + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t); + const size_t tmp_offset = + ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT); const size_t dst_binding_size = ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT); @@ -1813,10 +1840,10 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr }; const uint32_t total_wg_init = npr * nrows; - const uint32_t max_wg = ctx->limits.maxComputeWorkgroupsPerDimension; - const uint32_t wg_x_init = std::min(total_wg_init, max_wg); - const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); - std::vector init_entries = { + const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + const uint32_t wg_x_init = std::min(total_wg_init, max_wg); + const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); + std::vector init_entries = { { .binding = 0, .buffer = ggml_webgpu_tensor_buf(src), .offset = ggml_webgpu_tensor_align_offset(ctx, src), @@ -1912,19 +1939,20 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { .vec4 = false, - .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; webgpu_pipeline pipeline; // TODO: remove guard once pipeline caches are per-thread { - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); auto it = ctx->cumsum_pipelines.find(1); if (it != ctx->cumsum_pipelines.end()) { pipeline = it->second; } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum"); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); ctx->cumsum_pipelines.emplace(1, pipeline); } } @@ -1956,20 +1984,21 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { .vec4 = false, - .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; webgpu_pipeline pipeline; { // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); auto it = ctx->sum_rows_pipelines.find(1); if (it != ctx->sum_rows_pipelines.end()) { pipeline = it->second; } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows"); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); ctx->sum_rows_pipelines.emplace(1, pipeline); } } @@ -2070,8 +2099,7 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); - ggml_backend_webgpu_context * backend_ctx = static_cast(backend->context); - webgpu_context ctx = backend_ctx->webgpu_ctx; + webgpu_context ctx = get_per_thread_webgpu_context(backend->device); WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); @@ -2090,7 +2118,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str if (commands.size() >= batch_size) { futures.push_back(ggml_backend_webgpu_submit(ctx, commands)); // Process events and check for completed submissions - ctx->instance.ProcessEvents(); + ctx->global_ctx->instance.ProcessEvents(); ggml_backend_webgpu_wait(ctx, futures, false); commands.clear(); } @@ -2150,7 +2178,8 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor); - ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; + webgpu_context webgpu_ctx = get_per_thread_webgpu_context(buffer->buft->device); + ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); @@ -2159,8 +2188,8 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe // This is a trick to set all bytes of a u32 to the same 1 byte value. uint32_t val32 = (uint32_t) value * 0x01010101; - ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size); - WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->webgpu_ctx); + ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset, size); + WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, webgpu_ctx); } static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, @@ -2170,7 +2199,7 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, size_t size) { WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; - webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; + webgpu_context webgpu_ctx = get_per_thread_webgpu_context(buffer->buft->device); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); @@ -2194,7 +2223,7 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, remaining_size); } else { // wait for WriteBuffer to complete - webgpu_ctx->instance.WaitAny( + webgpu_ctx->global_ctx->instance.WaitAny( webgpu_ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { if (status != wgpu::QueueWorkDoneStatus::Success) { @@ -2216,8 +2245,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; - wgpu::Device device = webgpu_ctx->device; + webgpu_context webgpu_ctx = get_per_thread_webgpu_context(buffer->buft->device); + wgpu::Device device = webgpu_ctx->global_ctx->device; size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; @@ -2227,42 +2256,46 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, final_size = size + (4 - (size % 4)); } - std::lock_guard lock(webgpu_ctx->mutex); + std::lock_guard lock(webgpu_ctx->global_ctx->mutex); - if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) { + if (webgpu_ctx->global_ctx->get_tensor_staging_buf == nullptr || + webgpu_ctx->global_ctx->get_tensor_staging_buf.GetSize() < final_size) { // Create a new staging buffer if it doesn't exist or is too small - if (webgpu_ctx->get_tensor_staging_buf) { - webgpu_ctx->get_tensor_staging_buf.Destroy(); + if (webgpu_ctx->global_ctx->get_tensor_staging_buf) { + webgpu_ctx->global_ctx->get_tensor_staging_buf.Destroy(); } - ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size, + ggml_webgpu_create_buffer(device, webgpu_ctx->global_ctx->get_tensor_staging_buf, final_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf"); } // Copy the data from the buffer to the staging buffer wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size); + encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->global_ctx->get_tensor_staging_buf, 0, + final_size); wgpu::CommandBuffer commands = encoder.Finish(); // Submit the command buffer to the queue webgpu_ctx->queue.Submit(1, &commands); // Map the staging buffer to read the data - ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size); + ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->global_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, + final_size); // Must specify size here since the staging buffer might be larger than the tensor size - const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size); + const void * mapped_range = webgpu_ctx->global_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size); // Copy the data from the mapped range to the output buffer std::memcpy(data, mapped_range, size); - webgpu_ctx->get_tensor_staging_buf.Unmap(); + webgpu_ctx->global_ctx->get_tensor_staging_buf.Unmap(); WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, webgpu_ctx); } static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")"); WEBGPU_CPU_PROFILE_TOTAL_START(clear); - ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; - ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size); - WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->webgpu_ctx); + ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; + webgpu_context webgpu_ctx = get_per_thread_webgpu_context(buffer->buft->device); + ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size); + WEBGPU_CPU_PROFILE_TOTAL_END(clear, webgpu_ctx); } static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { @@ -2292,28 +2325,28 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b int buffer_id = buffer_count++; std::string buf_name = "tensor_buf" + std::to_string(buffer_id); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes"); - ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); + + webgpu_context ctx = get_per_thread_webgpu_context(buft->device); wgpu::Buffer buf; - ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT), + ggml_webgpu_create_buffer(ctx->global_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT), wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, buf_name.c_str()); - ggml_backend_webgpu_buffer_context * buf_ctx = - new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf, buf_name); + ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(buf, buf_name); return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size); } static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); - return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment; + webgpu_context ctx = get_per_thread_webgpu_context(buft->device); + return ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; } // maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding. static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { - ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); - return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize; + webgpu_context ctx = get_per_thread_webgpu_context(buft->device); + return ctx->global_ctx->capabilities.limits.maxStorageBufferBindingSize; } static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, @@ -2322,16 +2355,19 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer size_t res = ggml_nbytes(tensor); switch (tensor->op) { case GGML_OP_ARGSORT: - res = ROUNDUP_POW2(res * 2 + ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment, - WEBGPU_STORAGE_BUF_BINDING_MULT); + res = ROUNDUP_POW2( + res * 2 + ctx->reg_context->wgpu_global_ctx.capabilities.limits.minStorageBufferOffsetAlignment, + WEBGPU_STORAGE_BUF_BINDING_MULT); break; case GGML_OP_TOP_K: { const ggml_tensor * src0 = tensor->src[0]; if (src0) { const size_t full = sizeof(int32_t) * ggml_nelements(src0); - res = ROUNDUP_POW2(full * 2 + ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment, - WEBGPU_STORAGE_BUF_BINDING_MULT); + res = ROUNDUP_POW2( + full * 2 + + ctx->reg_context->wgpu_global_ctx.capabilities.limits.minStorageBufferOffsetAlignment, + WEBGPU_STORAGE_BUF_BINDING_MULT); } } break; @@ -2356,10 +2392,11 @@ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_ } static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); + webgpu_context webgpu_ctx = get_per_thread_webgpu_context(dev); + // TODO: for now, return maxBufferSize as both free and total memory // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates. - uint64_t max_buffer_size = ctx->webgpu_ctx->limits.maxBufferSize; + uint64_t max_buffer_size = webgpu_ctx->global_ctx->capabilities.limits.maxBufferSize; // If we're on a 32-bit system, clamp to UINTPTR_MAX #if UINTPTR_MAX < UINT64_MAX uint64_t max_ptr_size = static_cast(UINTPTR_MAX); @@ -2404,64 +2441,67 @@ static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_si static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { // we use the maximum workgroup size for the memset pipeline - size_t max_threads = WEBGPU_MAX_WG_SIZE * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension; + size_t max_threads = + WEBGPU_MAX_WG_SIZE * webgpu_ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; // Size the bytes_per_thread so that the largest buffer size can be handled - webgpu_ctx->memset_bytes_per_thread = CEIL_DIV(webgpu_ctx->limits.maxStorageBufferBindingSize, max_threads); + webgpu_ctx->memset_bytes_per_thread = + CEIL_DIV(webgpu_ctx->global_ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads); std::vector constants(2); - constants[0].key = "wg_size"; - constants[0].value = WEBGPU_MAX_WG_SIZE; - constants[1].key = "bytes_per_thread"; - constants[1].value = webgpu_ctx->memset_bytes_per_thread; - webgpu_ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_memset, "memset", constants); + constants[0].key = "wg_size"; + constants[0].value = WEBGPU_MAX_WG_SIZE; + constants[1].key = "bytes_per_thread"; + constants[1].value = webgpu_ctx->memset_bytes_per_thread; + webgpu_ctx->memset_pipelines[0] = + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_memset, "memset", constants); } static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { // Q4/Q5/Q8 classic quantizations webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); // K-quantizations webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); // IQ quantizations (2-, 3-, 4-bit variants) webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); // 1-bit and 4-bit IQ variants webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); std::string proc_mul_mat_f32_f32; std::string proc_mul_mat_f32_f32_vec; @@ -2476,15 +2516,19 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { #ifndef __EMSCRIPTEN__ if (webgpu_ctx->supports_subgroup_matrix) { std::map sg_matrix_repls; - sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size); + sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = + std::to_string(webgpu_ctx->global_ctx->capabilities.max_subgroup_size); sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K); sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M); sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M); sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N); - sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m); - sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n); - sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k); + sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = + std::to_string(webgpu_ctx->global_ctx->capabilities.subgroup_matrix_config.M); + sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = + std::to_string(webgpu_ctx->global_ctx->capabilities.subgroup_matrix_config.N); + sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = + std::to_string(webgpu_ctx->global_ctx->capabilities.subgroup_matrix_config.K); proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); proc_mul_mat_f32_f32_vec = @@ -2522,21 +2566,21 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { #endif webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); std::vector mul_mat_vec_constants(3); mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; @@ -2547,171 +2591,171 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); } static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); } static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants); webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants); webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants); webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); } static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32, "add_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32, "add_f32", constants); webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16, "add_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16, "add_f16", constants); webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants); webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants); } static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32, "sub_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32, "sub_f32", constants); webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16, "sub_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16, "sub_f16", constants); webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants); webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants); } static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32, "mul_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32, "mul_f32", constants); webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16, "mul_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16, "mul_f16", constants); webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants); webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants); } static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32, "div_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32, "div_f32", constants); webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16, "div_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16, "div_f16", constants); webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants); webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants); } static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); webgpu_ctx->rms_norm_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm, "rms_norm", constants); - webgpu_ctx->rms_norm_pipelines[1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants); + webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants); } static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32, "rope_f32", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants); webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16, "rope_f16", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants); webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); } static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { @@ -2719,68 +2763,68 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { // REGLU webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32, "reglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16, "reglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants); // GEGLU webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32, "geglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16, "geglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants); // SWIGLU webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants); // SWIGLU_OAI webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); // GEGLU_ERF webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); // GEGLU_QUICK webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); } static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->scale_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32, "scale_f32", constants); - webgpu_ctx->scale_pipelines[1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32_inplace, "scale_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32, "scale_f32", constants); + webgpu_ctx->scale_pipelines[1] = ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32_inplace, + "scale_f32_inplace", constants); } static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { @@ -2788,56 +2832,247 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { // f32 (no mask) webgpu_ctx->soft_max_pipelines[2][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants); - webgpu_ctx->soft_max_pipelines[2][0][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[2][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants); + webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants); + webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants); webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); // f32 mask (mask_type = 0) - webgpu_ctx->soft_max_pipelines[0][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants); + webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants); webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[0][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); + webgpu_ctx->soft_max_pipelines[0][1][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, + "soft_max_f32_mask_f32_sink_inplace", constants); // f16 mask (mask_type = 1) - webgpu_ctx->soft_max_pipelines[1][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants); + webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants); webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); - webgpu_ctx->soft_max_pipelines[1][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); + webgpu_ctx->soft_max_pipelines[1][1][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, + "soft_max_f32_mask_f16_sink_inplace", constants); +} + +static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { + wgpu::RequestAdapterOptions options = {}; + +#ifndef __EMSCRIPTEN__ + // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 + const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; + wgpu::DawnTogglesDescriptor adapterTogglesDesc; + adapterTogglesDesc.enabledToggles = adapterEnabledToggles; + adapterTogglesDesc.enabledToggleCount = 2; + options.nextInChain = &adapterTogglesDesc; +#endif + + ctx->wgpu_global_ctx.instance.WaitAny( + ctx->wgpu_global_ctx.instance.RequestAdapter( + &options, wgpu::CallbackMode::AllowSpontaneous, + [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + ctx->wgpu_global_ctx.adapter = std::move(adapter); + }), + UINT64_MAX); + GGML_ASSERT(ctx->wgpu_global_ctx.adapter != nullptr); + + ctx->wgpu_global_ctx.adapter.GetLimits(&ctx->wgpu_global_ctx.capabilities.limits); + + wgpu::AdapterInfo info{}; +#ifndef __EMSCRIPTEN__ + wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; + if (ctx->wgpu_global_ctx.adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + info.nextInChain = &subgroup_matrix_configs; + } +#endif + ctx->wgpu_global_ctx.adapter.GetInfo(&info); + + wgpu::SupportedFeatures features; + ctx->wgpu_global_ctx.adapter.GetFeatures(&features); + // we require f16 support + GGML_ASSERT(ctx->wgpu_global_ctx.adapter.HasFeature(wgpu::FeatureName::ShaderF16)); + +#ifndef __EMSCRIPTEN__ + // Only support square f16 matrices of size 8 or 16 for now + bool valid_subgroup_matrix_config = false; + if (ctx->wgpu_global_ctx.adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { + const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; + if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && + config.componentType == wgpu::SubgroupMatrixComponentType::F16 && + config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { + ctx->wgpu_global_ctx.capabilities.subgroup_matrix_config = config; + valid_subgroup_matrix_config = true; + break; + } + } + } + ctx->wgpu_global_ctx.capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; +#endif + + // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. + // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. + ctx->wgpu_global_ctx.capabilities.max_subgroup_size = info.subgroupMaxSize; + + // Initialize device + std::vector required_features = { wgpu::FeatureName::ShaderF16 }; + +#ifndef __EMSCRIPTEN__ + required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); + if (ctx->wgpu_global_ctx.capabilities.supports_subgroup_matrix) { + required_features.push_back(wgpu::FeatureName::Subgroups); + required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); + } +#endif + +#ifdef GGML_WEBGPU_GPU_PROFILE + required_features.push_back(wgpu::FeatureName::TimestampQuery); +#endif + + wgpu::DeviceDescriptor dev_desc; + dev_desc.requiredLimits = &ctx->wgpu_global_ctx.capabilities.limits; + dev_desc.requiredFeatures = required_features.data(); + dev_desc.requiredFeatureCount = required_features.size(); + dev_desc.SetDeviceLostCallback( + wgpu::CallbackMode::AllowSpontaneous, + [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + GGML_UNUSED(device); + GGML_UNUSED(reason); + GGML_UNUSED(message); + //TODO: uncomment once proper free logic is in place + //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), + //std::string(message).c_str()); + }); + dev_desc.SetUncapturedErrorCallback( + [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { + GGML_UNUSED(device); + GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), + std::string(message).c_str()); + }); + +#ifndef __EMSCRIPTEN__ + // Enable Dawn-specific toggles to increase native performance + // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, + // only for native performance? + const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", + "disable_polyfills_on_integer_div_and_mod" }; + const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; + wgpu::DawnTogglesDescriptor deviceTogglesDesc; + deviceTogglesDesc.enabledToggles = deviceEnabledToggles; + deviceTogglesDesc.enabledToggleCount = 4; + deviceTogglesDesc.disabledToggles = deviceDisabledToggles; + deviceTogglesDesc.disabledToggleCount = 1; + + dev_desc.nextInChain = &deviceTogglesDesc; +#endif + + ctx->wgpu_global_ctx.instance.WaitAny( + ctx->wgpu_global_ctx.adapter.RequestDevice( + &dev_desc, wgpu::CallbackMode::AllowSpontaneous, + [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { + if (status != wgpu::RequestDeviceStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str()); + return; + } + ctx->wgpu_global_ctx.device = std::move(device); + }), + UINT64_MAX); + GGML_ASSERT(ctx->wgpu_global_ctx.device != nullptr); + + GGML_LOG_INFO( + "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " + "device_desc: %s\n", + info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID, + std::string(info.device).c_str(), std::string(info.description).c_str()); + return true; +} + +// Fetch the webgpu_context for this thread (indexed by device for multiple devices running per thread) +// If it hasn't been created yet, then initiailize one. +static webgpu_context get_per_thread_webgpu_context(ggml_backend_dev_t dev) { + thread_local std::unordered_map map; + + ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; + if (auto it = map.find(dev_ctx); it != map.end()) { + return it->second; + } + + webgpu_context webgpu_ctx = std::make_shared(); + webgpu_ctx->global_ctx = &(dev_ctx->reg_context->wgpu_global_ctx); + webgpu_ctx->queue = webgpu_ctx->global_ctx->device.GetQueue(); + webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); +#ifdef GGML_WEBGPU_GPU_PROFILE + // Initialize buffer pool for timestamp queries, used for profiling + webgpu_ctx->timestamp_query_buf_pool.init(webgpu_ctx->wgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, + WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, + wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); +#endif + webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); + + ggml_webgpu_init_memset_pipeline(webgpu_ctx); + ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); + ggml_webgpu_init_get_rows_pipeline(webgpu_ctx); + ggml_webgpu_init_cpy_pipeline(webgpu_ctx); + ggml_webgpu_init_add_pipeline(webgpu_ctx); + ggml_webgpu_init_sub_pipeline(webgpu_ctx); + ggml_webgpu_init_mul_pipeline(webgpu_ctx); + ggml_webgpu_init_div_pipeline(webgpu_ctx); + ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); + ggml_webgpu_init_rope_pipeline(webgpu_ctx); + ggml_webgpu_init_glu_pipeline(webgpu_ctx); + ggml_webgpu_init_scale_pipeline(webgpu_ctx); + ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); +#ifdef GGML_WEBGPU_DEBUG + // Initialize debug buffers + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->debug_host_buf, + WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->debug_dev_buf, + WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf"); +#endif + + map.try_emplace(dev_ctx, webgpu_ctx); + return webgpu_ctx; } -// TODO: move most initialization logic here static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { GGML_UNUSED(params); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()"); - ggml_backend_webgpu_device_context * dev_ctx = static_cast(dev->context); - webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx; + ggml_backend_webgpu_device_context * dev_ctx = static_cast(dev->context); - static ggml_backend_webgpu_context backend_ctx; - backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; - backend_ctx.webgpu_ctx = webgpu_ctx; + auto * backend_ctx = new ggml_backend_webgpu_context(); + backend_ctx->name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; // See GGML Backend Interface section - static ggml_backend backend = { + auto * backend = new ggml_backend(); + *backend = { /* .guid = */ ggml_backend_webgpu_guid(), /* .interface = */ ggml_backend_webgpu_i, /* .device = */ dev, - /* .context = */ &backend_ctx, + /* .context = */ backend_ctx, }; - return &backend; + return backend; } static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) { @@ -2854,7 +3089,8 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm }, /* .device = */ dev, - /* .context = */ NULL, + /* .context = */ + NULL }; return &ggml_backend_webgpu_buffer_type; @@ -2893,18 +3129,18 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) { } static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); - - webgpu_context webgpu_ctx = ctx->webgpu_ctx; + webgpu_context webgpu_ctx = get_per_thread_webgpu_context(dev); ggml_tensor * src0 = op->src[0]; ggml_tensor * src1 = op->src[1]; ggml_tensor * src2 = op->src[2]; // on smaller devices (or CI), tensors may be larger than the max storage buffer size - if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || - (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || - (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) { + if (ggml_nbytes(op) > webgpu_ctx->global_ctx->capabilities.limits.maxStorageBufferBindingSize || + (src0 != nullptr && + ggml_nbytes(src0) > webgpu_ctx->global_ctx->capabilities.limits.maxStorageBufferBindingSize) || + (src1 != nullptr && + ggml_nbytes(src1) > webgpu_ctx->global_ctx->capabilities.limits.maxStorageBufferBindingSize)) { return false; } @@ -2988,7 +3224,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize; + size_t limit_bytes = webgpu_ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; const bool has_mask = op->src[3] != nullptr; const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 && (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; @@ -3099,10 +3335,13 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const default: break; } - if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || - (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || - (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) || - (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) { + if (ggml_nbytes(op) > webgpu_ctx->global_ctx->capabilities.limits.maxStorageBufferBindingSize || + (src0 != nullptr && + ggml_nbytes(src0) > webgpu_ctx->global_ctx->capabilities.limits.maxStorageBufferBindingSize) || + (src1 != nullptr && + ggml_nbytes(src1) > webgpu_ctx->global_ctx->capabilities.limits.maxStorageBufferBindingSize) || + (src2 != nullptr && + ggml_nbytes(src2) > webgpu_ctx->global_ctx->capabilities.limits.maxStorageBufferBindingSize)) { supports_op = false; WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: "); } @@ -3156,6 +3395,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) { // TODO: Does this need to be thread safe? Is it only called once? // TODO: move most logic to device_init function so backend can be freed/initialized properly // Only one device is supported for now + static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) { GGML_ASSERT(index == 0); WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()"); @@ -3164,188 +3404,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_backend_webgpu_reg_context * reg_ctx = static_cast(reg->context); - webgpu_context ctx = reg_ctx->webgpu_ctx; - - wgpu::RequestAdapterOptions options = {}; - -#ifndef __EMSCRIPTEN__ - // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 - const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; - wgpu::DawnTogglesDescriptor adapterTogglesDesc; - adapterTogglesDesc.enabledToggles = adapterEnabledToggles; - adapterTogglesDesc.enabledToggleCount = 2; - options.nextInChain = &adapterTogglesDesc; -#endif - - ctx->instance.WaitAny(ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, - [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - ctx->adapter = std::move(adapter); - }), - UINT64_MAX); - GGML_ASSERT(ctx->adapter != nullptr); - - ctx->adapter.GetLimits(&ctx->limits); - - wgpu::AdapterInfo info{}; -#ifndef __EMSCRIPTEN__ - wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; - if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { - info.nextInChain = &subgroup_matrix_configs; - } -#endif - ctx->adapter.GetInfo(&info); - - wgpu::SupportedFeatures features; - ctx->adapter.GetFeatures(&features); - // we require f16 support - GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); - -#ifndef __EMSCRIPTEN__ - // Only support square f16 matrices of size 8 or 16 for now - bool valid_subgroup_matrix_config = false; - if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { - for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { - const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; - if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && - config.componentType == wgpu::SubgroupMatrixComponentType::F16 && - config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { - ctx->sg_mat_m = config.M; - ctx->sg_mat_n = config.N; - ctx->sg_mat_k = config.K; - valid_subgroup_matrix_config = true; - break; - } - } - } - - ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; -#endif - // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. - // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. - ctx->max_subgroup_size = info.subgroupMaxSize; - - // Initialize device - std::vector required_features = { wgpu::FeatureName::ShaderF16 }; - -#ifndef __EMSCRIPTEN__ - required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); - if (ctx->supports_subgroup_matrix) { - required_features.push_back(wgpu::FeatureName::Subgroups); - required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); - } -#endif - -#ifdef GGML_WEBGPU_GPU_PROFILE - required_features.push_back(wgpu::FeatureName::TimestampQuery); -#endif - - wgpu::DeviceDescriptor dev_desc; - dev_desc.requiredLimits = &ctx->limits; - dev_desc.requiredFeatures = required_features.data(); - dev_desc.requiredFeatureCount = required_features.size(); - dev_desc.SetDeviceLostCallback( - wgpu::CallbackMode::AllowSpontaneous, - [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { - GGML_UNUSED(device); - GGML_UNUSED(reason); - GGML_UNUSED(message); - //TODO: uncomment once proper free logic is in place - //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), - //std::string(message).c_str()); - }); - dev_desc.SetUncapturedErrorCallback( - [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { - GGML_UNUSED(device); - GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), - std::string(message).c_str()); - }); - -#ifndef __EMSCRIPTEN__ - // Enable Dawn-specific toggles to increase native performance - // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, - // only for native performance? - const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", - "disable_polyfills_on_integer_div_and_mod" }; - const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; - wgpu::DawnTogglesDescriptor deviceTogglesDesc; - deviceTogglesDesc.enabledToggles = deviceEnabledToggles; - deviceTogglesDesc.enabledToggleCount = 4; - deviceTogglesDesc.disabledToggles = deviceDisabledToggles; - deviceTogglesDesc.disabledToggleCount = 1; - - dev_desc.nextInChain = &deviceTogglesDesc; -#endif - - ctx->instance.WaitAny(ctx->adapter.RequestDevice( - &dev_desc, wgpu::CallbackMode::AllowSpontaneous, - [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { - if (status != wgpu::RequestDeviceStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", - std::string(message).c_str()); - return; - } - ctx->device = std::move(device); - }), - UINT64_MAX); - GGML_ASSERT(ctx->device != nullptr); - - // Initialize (compute) queue - ctx->queue = ctx->device.GetQueue(); - - // Create buffer pool for shader parameters - ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); - -#ifdef GGML_WEBGPU_GPU_PROFILE - // Initialize buffer pool for timestamp queries (profiling) - ctx->timestamp_query_buf_pool.init(ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, - WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, - wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, - wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); -#endif - - ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); - - ggml_webgpu_init_memset_pipeline(ctx); - ggml_webgpu_init_mul_mat_pipeline(ctx); - ggml_webgpu_init_get_rows_pipeline(ctx); - ggml_webgpu_init_cpy_pipeline(ctx); - ggml_webgpu_init_add_pipeline(ctx); - ggml_webgpu_init_sub_pipeline(ctx); - ggml_webgpu_init_mul_pipeline(ctx); - ggml_webgpu_init_div_pipeline(ctx); - ggml_webgpu_init_rms_norm_pipeline(ctx); - ggml_webgpu_init_rope_pipeline(ctx); - ggml_webgpu_init_glu_pipeline(ctx); - ggml_webgpu_init_scale_pipeline(ctx); - ggml_webgpu_init_soft_max_pipeline(ctx); - -#ifdef GGML_WEBGPU_DEBUG - // Initialize debug buffers - ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf"); - ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), - wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf"); -#endif + create_webgpu_device(reg_ctx); static ggml_backend_webgpu_device_context device_ctx; - device_ctx.webgpu_ctx = ctx; device_ctx.device_name = GGML_WEBGPU_NAME; - device_ctx.device_desc = info.description; - - GGML_LOG_INFO( - "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " - "device_desc: %s\n", - info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID, - std::string(info.device).c_str(), std::string(info.description).c_str()); + // TODO: we get the device description when actually initializing the webgpu device, so move this to backend context? + device_ctx.device_desc = GGML_WEBGPU_NAME; + device_ctx.reg_context = reg_ctx; // See GGML Backend Device Interface section static ggml_backend_device device = { @@ -3370,10 +3435,7 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { ggml_backend_reg_t ggml_backend_webgpu_reg() { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); - webgpu_context webgpu_ctx = std::make_shared(); - static ggml_backend_webgpu_reg_context ctx; - ctx.webgpu_ctx = webgpu_ctx; ctx.name = GGML_WEBGPU_NAME; ctx.device_count = 1; @@ -3390,7 +3452,8 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { instance_descriptor.nextInChain = &instanceTogglesDesc; #endif - webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); + wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor); + ctx.wgpu_global_ctx.instance = std::move(inst); #ifdef __EMSCRIPTEN__ if (webgpu_ctx->instance == nullptr) { @@ -3398,7 +3461,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { return nullptr; } #endif - GGML_ASSERT(webgpu_ctx->instance != nullptr); + GGML_ASSERT(ctx.wgpu_global_ctx.instance != nullptr); static ggml_backend_reg reg = { /* .api_version = */ GGML_BACKEND_API_VERSION,