Enable graph capture for webgpu #1848
Conversation
Is this a result of internal discussion on the last comment from #1720, or some other approach? Also, have you tested your webgpu implementation on a UMA device (i.e. arm64/Apple Silicon)? If so, and you see the same issue I have, we can discuss it. I have met with Arm reps on the KleidiAI team regarding the issue and may be able to give insights.
I think the preferred solution is solution two in #1720, after syncing with colleagues. And currently, onnxruntime has exposed
So what kind of issue do you see on Arm, and what are your insights?
I have seen an issue whereby calls to CopyDeviceToCpu cause a major slowdown as well as visual stuttering with under-utilisation of the hardware. This is the case for my own solution (similar to your solution three in #1720) as well as your implementation in that PR for solution one (with USE_WEBGPU=OFF). Essentially, execution of the model is very fast, while the first CopyDeviceToCpu call takes longer than the model itself. This happens whether it is transferring the logits (~262 KB for Llama-3B) or the decoded sequence (~16 KB for a 2000-token int64 sequence). I have a version that decodes entirely on GPU with WebGPU kernels and only calls CopyDeviceToCpu once at the end of generation, and I still see the same behaviour. See the following trace to illustrate (using ORT GenAI's internal tracing system from #1524).

This only occurs on unified-memory systems, which are mostly Arm in the market right now (Qualcomm, Apple Silicon, etc.), but the same behaviour appeared when we tested on an AMD Zen 2 unified-memory APU. When discussing with Arm reps who worked on the KleidiAI EP (which is entirely CPU based), they mentioned that it might be to do with the synchronisation between CPU and GPU, but admitted that if inference on Arm chips could be done on GPU, it would likely see better performance and energy efficiency than their CPU-based KleidiAI approach. However, I have also seen this slowdown on some lower-end dedicated GPUs such as the AMD Radeon RX 5700 XT.

Have you tested at all with any unified-memory devices? If you would like any more information on my findings, let me know. We're actively trying to resolve this issue.
My understanding is that the CopyDeviceToCpu call might also include partial model inference time. How do you ensure that all previous GPU work is completed before measuring the CopyDeviceToCpu time? If you have access to the raw Dawn/WebGPU API in your code, you might want to call https://www.w3.org/TR/webgpu/#dom-gpuqueue-onsubmittedworkdone before proceeding with CopyDeviceToCpu. This method confirms that all prior GPU tasks are finished. Please share your findings. Thank you.
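To illustrate the measurement pitfall being discussed (no GPU or Dawn required), here is a self-contained C++ analogy: a pending asynchronous job inflates the time attributed to the first blocking read unless you wait for completion first, which is the role `onSubmittedWorkDone` plays in WebGPU. All names in this sketch are illustrative, not part of any of the APIs above.

```cpp
#include <chrono>
#include <future>
#include <thread>

// Simulates the timing pitfall: if "GPU" work is still in flight, the first
// blocking readback absorbs the remaining execution time. Returns the
// milliseconds charged to the blocking read.
long long measure_blocking_read_ms(bool sync_first) {
  using namespace std::chrono;
  // Pretend GPU work: runs asynchronously for ~100 ms.
  auto work = std::async(std::launch::async, [] {
    std::this_thread::sleep_for(milliseconds(100));
  });
  if (sync_first) {
    work.wait();  // analogous to awaiting onSubmittedWorkDone()
  }
  auto t0 = steady_clock::now();
  work.wait();    // analogous to the first CopyDeviceToCpu call
  return duration_cast<milliseconds>(steady_clock::now() - t0).count();
}
```

Without the synchronization, nearly all of the 100 ms of pending work lands in the "copy" timing; with it, the blocking read itself is near-instant.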
This reverts commit 60619b2.
Pull request overview
This PR enables graph capture support for WebGPU by implementing device-CPU memory operations using ONNX Runtime's new CopyTensors API. It also upgrades the ONNX Runtime dependency from version 1.22.0 to 1.23.0 across the entire codebase.
Key changes:
- Implements `CopyDeviceToCpu`, `CopyCpuToDevice`, `CopyFrom`, and `Zero` methods for WebGPU using the new `CopyTensors` API
- Adds graph capture detection for WebGPU in config processing and model initialization
- Updates ONNX Runtime version to 1.23.0 across all build configurations and test requirements
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| src/webgpu/interface.cpp | Implements WebGPU memory operations with CopyTensors API; manages CPU/device memory transfers with fallback for offset copies |
| src/models/onnxruntime_api.h | Adds CopyTensors method declaration and OrtSyncStream wrapper for async operations |
| src/models/onnxruntime_inline.h | Implements CopyTensors wrapper with validation |
| src/models/model.cpp | Adds WebGPU graph capture check to determine device memory usage for inputs |
| src/config.cpp | Implements graph capture detection for WebGPU provider based on enableGraphCapture option |
| CMakeLists.txt | Suppresses MSVC warning C4819 for non-representable characters |
| test/python/*/ort/requirements.txt | Updates ONNX Runtime test dependencies to 1.23.0 |
| cmake/ortlib.cmake | Updates default ONNX Runtime version to 1.23.0 for all execution providers |
| examples/slm_engine/build_scripts/build_deps.py | Updates ORT version references to 1.23.0 |
| .pipelines/nuget-publishing.yml | Updates pipeline ORT version parameters to 1.23.0 |
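The graph-capture detection in src/config.cpp keys off the `enableGraphCapture` provider option. As a hedged sketch (the exact file layout may differ per model), enabling it in a model's genai_config.json could look like this; the option name comes from this PR, while the surrounding structure is illustrative:

```json
{
  "model": {
    "decoder": {
      "session_options": {
        "provider_options": [
          { "webgpu": { "enableGraphCapture": "1" } }
        ]
      }
    }
  }
}
```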
Pull request overview
Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.
Co-authored-by: Copilot <[email protected]>
This PR enables graph capture for webgpu. It implements the CopyDeviceToCpu/CopyCpuToDevice/CopyFrom/Zero functions using the new `CopyTensors` API. The ORT side needs PR [#26450](microsoft/onnxruntime#26450) applied to make this work for webgpu. The following will be implemented in follow-up PRs to get the full performance gain from graph capture (the original PR is #1720):

1. Support UpdateAttentionMask, UpdatePositionIds, and Cast to keep the whole pipeline on GPU.
2. Optimize CopyFrom with offsets.
