Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
54fe58c
First incorrect draft impl
adrianlizarraga Apr 2, 2026
b25bb09
Add IExecutionProvider::GetGraphCaptureNodeAssignmentPolicy() to remo…
adrianlizarraga Apr 2, 2026
3146b85
Minor update to docs. Stills need improvement
adrianlizarraga Apr 2, 2026
023c0c0
Add onnxruntime:: namespace qualifier for DML use of enum
adrianlizarraga Apr 3, 2026
bd71d7d
Merge branch 'main' into adrianl/PluginEp_CudaGraphCaptureReplay
adrianlizarraga Apr 3, 2026
857f0b4
Implement new graph capture/replay C APIs in webgpu's OrtEp implement…
adrianlizarraga Apr 3, 2026
0741b28
Add stubs with todos to cuda's plugin EP
adrianlizarraga Apr 3, 2026
f08d05b
Use only one enum for the graph capture node assignment policy
adrianlizarraga Apr 3, 2026
211bb7e
Update C doc comments
adrianlizarraga Apr 3, 2026
91fb965
First attempt at a webgpu plugin EP test that uses graph capture
adrianlizarraga Apr 3, 2026
c38ba6c
Add ifdef guards for webgpu plugin EP test
adrianlizarraga Apr 6, 2026
61b8ff6
Review comments: Update NvTensorRTRTX EP to return false from IExecut…
adrianlizarraga Apr 6, 2026
977dea1
Add InferenceSession::RunImpl() that takes in the recursive run depth
adrianlizarraga Apr 6, 2026
8645b8d
Ensure at least one compute node is assigned to the graph capturing EP
adrianlizarraga Apr 6, 2026
b158299
Add ORT_ENFORCE
adrianlizarraga Apr 6, 2026
0bc9cdc
Merge branch 'main' into adrianl/PluginEp_CudaGraphCaptureReplay
adrianlizarraga Apr 6, 2026
e289ae2
Fix graph capture/replay test for the webgpu plugin EP (runs locally)
adrianlizarraga Apr 6, 2026
ff09625
CXX API: add Ort::Env::CopyTensor() to copy only one tensor
adrianlizarraga Apr 6, 2026
0718dae
Update test EP registration to use correct ep lib name for various pl…
adrianlizarraga Apr 7, 2026
156a3cc
Update docs regarding the max number of attempts to call IsGraphCaptured
adrianlizarraga Apr 7, 2026
a7e491e
Merge branch 'main' into adrianl/PluginEp_CudaGraphCaptureReplay
adrianlizarraga Apr 9, 2026
2059886
Adjust test ifdef for webgpu EP test that is not really the plugin EP…
adrianlizarraga Apr 9, 2026
9205f22
Address doc comments
adrianlizarraga Apr 9, 2026
5694f6c
Check that OrtEp::IsGraphCaptured and OrtEp::ReplayGraph are implemen…
adrianlizarraga Apr 9, 2026
cb66889
Address comments
adrianlizarraga Apr 9, 2026
91851ac
Sync IO binding inputs in test
adrianlizarraga Apr 9, 2026
ff929f3
Improved documentation for graph annotation ID
adrianlizarraga Apr 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,15 @@ class IExecutionProvider {
return Status::OK();
}

/**
Get the node assignment validation policy for graph capture.
When graph capture is enabled, ORT validates that nodes are assigned to EPs
in a way compatible with graph capture. This tells ORT which policy to apply.

The default implementation returns OrtGraphCaptureNodeAssignmentPolicy_ALL_NODES_ON_EP, i.e. every node
in the main graph must be assigned to this EP and no CPU fallback is allowed. EPs that can tolerate CPU
nodes used for shape computation override this to return
OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES.
*/
virtual OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const {
return OrtGraphCaptureNodeAssignmentPolicy_ALL_NODES_ON_EP;
}

/**
Called when session creation is complete
This provides an opportunity for execution providers to optionally synchronize and
Expand Down
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1389,6 +1389,10 @@ struct Env : detail::Base<OrtEnv> {
const std::vector<Value>& dst_tensors,
OrtSyncStream* stream) const; ///< Wraps OrtApi::CopyTensors

/// \brief Copies exactly one source tensor to one destination tensor. Wraps OrtApi::CopyTensors.
///
/// Convenience overload for the single-tensor case: the call is forwarded to OrtApi::CopyTensors with a
/// tensor count of 1, so callers do not need to build one-element vectors.
///
/// \param src_tensor The tensor to copy from.
/// \param dst_tensor The tensor to copy into.
/// \param stream Sync stream passed through unchanged to OrtApi::CopyTensors.
/// \return Status wrapping the OrtStatus returned by OrtApi::CopyTensors.
Status CopyTensor(const OrtValue* src_tensor, OrtValue* dst_tensor, OrtSyncStream* stream) const;

/// \brief Wraps OrtApi::SetPerSessionThreadPoolCallbacks
/// Stores work callbacks on the Env for per-session thread pools.
/// Only affects sessions created after this call. Does not affect global thread pools.
Expand Down
5 changes: 5 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,11 @@ inline Status Env::CopyTensors(const std::vector<Value>& src_tensors,
return Status(status);
}

inline Status Env::CopyTensor(const OrtValue* src_tensor, OrtValue* dst_tensor, OrtSyncStream* stream) const {
  // OrtApi::CopyTensors takes parallel arrays of tensors; taking the address of each pointer
  // parameter yields a one-element array, so no temporary containers are needed.
  return Status(GetApi().CopyTensors(p_, &src_tensor, &dst_tensor, stream, 1));
}

inline UnownedAllocator Env::CreateSharedAllocator(const OrtEpDevice* ep_device, OrtDeviceMemoryType mem_type,
OrtAllocatorType allocator_type,
const OrtKeyValuePairs* allocator_options) {
Expand Down
112 changes: 112 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2027,6 +2027,23 @@ typedef enum OrtEpDataLayout {
OrtEpDataLayout_Default = OrtEpDataLayout_NCHW,
} OrtEpDataLayout;

/**
 * \brief Node assignment policies for graph capture validation.
 *
 * When graph capture is enabled, ORT validates that nodes are assigned to EPs in a way that is
 * compatible with graph capture. An EP can specify which validation policy ORT should apply.
 *
 * The enumerator values are part of the C ABI and are therefore assigned explicitly.
 *
 * \since Version 1.26.
 */
typedef enum OrtGraphCaptureNodeAssignmentPolicy {
  /** All nodes in the main graph must be assigned to this EP. No CPU fallback is allowed. */
  OrtGraphCaptureNodeAssignmentPolicy_ALL_NODES_ON_EP = 0,

  /** Compute nodes must be on this EP. CPU nodes are allowed for shape computation as long as
   * no memory copy nodes exist. */
  OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES = 1,
} OrtGraphCaptureNodeAssignmentPolicy;

/**
* \brief The OrtEp struct provides functions to implement for an execution provider.
* \since Version 1.22.
Expand Down Expand Up @@ -2346,6 +2363,101 @@ struct OrtEp {
*/
ORT_API2_STATUS(CreateProfiler, _In_ OrtEp* this_ptr,
_Outptr_result_maybenull_ OrtEpProfilerImpl** profiler);

/** \brief Indicate whether the graph capturing mode (e.g., CUDA graph) is enabled for the provider.
*
* Graph capture allows an EP to record a sequence of device (e.g., GPU) operations during an initial run and replay
* them on subsequent runs, bypassing per-kernel CPU launch overhead.
*
* Applications enable graph capture via EP-specific provider options (e.g., `enable_cuda_graph=1`
* for the CUDA EP). An EP should return true from this function if it has been configured to enable
* graph capture/replay.
*
* **ORT graph capture/replay summary:**
* During OrtSession initialization, ORT calls OrtEp::IsGraphCaptureEnabled() on each EP in the order specified during
* provider registration with the session. If an EP returns true, ORT validates that the graph is suitable for
* graph capture, and if so, caches the EP for graph capture during the next run. The graph validation ensures
* that there are no control flow nodes and that node-to-EP assignments are compatible with the policy specified
* by the EP via OrtEp::GetGraphCaptureNodeAssignmentPolicy().
* Note that an OrtSession only supports graph capture for one EP (i.e., the first EP to claim support).
*
* During the first call to OrtApi::Run() for the OrtSession, ORT performs multiple internal runs of the model
* until the EP indicates that the graph has been captured by returning `true` from `OrtEp::IsGraphCaptured()`.
* If the EP is unable to capture the graph within 8 runs, the call to OrtApi::Run() returns an error OrtStatus.
* Each internal run invokes `OrtEp::OnRunStart()`, normal execution, and `OrtEp::OnRunEnd()`. EPs should use
* these run callbacks to track the number of necessary warm-up runs and begin/end graph capture when ready.
*
* After successful graph capture, subsequent calls to OrtApi::Run() skip normal execution and ORT instead calls
* `OrtEp::ReplayGraph()` directly.
*
* Applications can capture and replay multiple graphs (e.g., one per distinct input shape) by setting the
* `"gpu_graph_id"` run config entry via `OrtApi::AddRunConfigEntry()` to different integer values. ORT passes
* the value as the `graph_annotation_id` parameter to `OrtEp::IsGraphCaptured()` and `OrtEp::ReplayGraph()`.
*
* \param[in] this_ptr The OrtEp instance.
* \return true if graph capture mode is enabled, false otherwise.
*
* \note Implementation of this function is optional. If set to NULL, ORT assumes graph capture is not enabled.
* \note If this function returns true, `OrtEp::IsGraphCaptured` and `OrtEp::ReplayGraph` must also be implemented.
* If either is NULL, ORT will log a warning and ignore this EP for graph capture.
*
* \since Version 1.26.
*/
ORT_API_T(bool, IsGraphCaptureEnabled, _In_ const OrtEp* this_ptr);

/** \brief Indicate whether a graph has been captured and instantiated.
*
* ORT calls this before each `Session::Run()`. If true, ORT calls `ReplayGraph()` instead of
* normal execution. After a run where this returns false, ORT automatically re-runs the model until it
* returns true (handling warm-up runs transparently), subject to the maximum number of internal runs
* documented for `OrtEp::IsGraphCaptureEnabled`.
*
* \param[in] this_ptr The OrtEp instance.
* \param[in] graph_annotation_id Identifies which captured graph to query.
* Applications can set this value via `OrtApi::AddRunConfigEntry()` with the key `"gpu_graph_id"`.
* The default value is 0 when the run config entry is not set.
* Setting different IDs allows the EP to capture and manage multiple graphs (e.g., one per
* distinct input shape). A value of -1 means graph capture/replay should be skipped for this run.
* \return true if the graph identified by `graph_annotation_id` has been captured, false otherwise.
*
* \note This function must be implemented if `OrtEp::IsGraphCaptureEnabled` is implemented and may return true.
*
* \since Version 1.26.
*/
ORT_API_T(bool, IsGraphCaptured, _In_ const OrtEp* this_ptr, _In_ int graph_annotation_id);

/** \brief Run the instantiated (captured) graph.
*
* Called by ORT instead of normal execution when `IsGraphCaptured()` returns true for the same
* `graph_annotation_id`.
*
* \param[in] this_ptr The OrtEp instance.
* \param[in] graph_annotation_id Identifies which captured graph to replay.
* Applications can set this value via `OrtApi::AddRunConfigEntry()` with the key `"gpu_graph_id"`.
* The default value is 0 when the run config entry is not set.
* A value of -1 means graph replay should be skipped for this run.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \note This function must be implemented if `OrtEp::IsGraphCaptureEnabled` is implemented and may return true.
*
* \since Version 1.26.
*/
ORT_API2_STATUS(ReplayGraph, _In_ OrtEp* this_ptr, _In_ int graph_annotation_id);

/** \brief Get the node assignment validation policy for graph capture.
*
* When graph capture is enabled, ORT validates that nodes are assigned to EPs in a way that is
* compatible with graph capture. This function tells ORT which validation policy to apply.
* See ::OrtGraphCaptureNodeAssignmentPolicy for the available policies.
*
* \param[in] this_ptr The OrtEp instance.
* \return The node assignment policy for graph capture.
*
* \note Implementation of this function is optional. If set to NULL, ORT uses
* OrtGraphCaptureNodeAssignmentPolicy_ALL_NODES_ON_EP (strictest validation).
*
* \since Version 1.26.
*/
ORT_API_T(OrtGraphCaptureNodeAssignmentPolicy, GetGraphCaptureNodeAssignmentPolicy,
_In_ const OrtEp* this_ptr);
};

/** \brief The function signature that ORT will call to create OrtEpFactory instances.
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ class CUDAExecutionProvider : public IExecutionProvider {
bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const override;
Status ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) override;
// Graph capture validation policy for the CUDA EP: compute nodes must be assigned to this EP, but CPU
// nodes used for shape computation are allowed as long as no memory copy nodes exist. This relaxes the
// IExecutionProvider default, which requires every node to be on the EP.
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override {
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
}
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
std::vector<AllocatorPtr> CreatePreferredAllocators() override;
Expand Down
31 changes: 31 additions & 0 deletions onnxruntime/core/providers/cuda/plugin/cuda_ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@
Compile = nullptr;
ReleaseNodeComputeInfos = nullptr;

// Graph capture/replay
IsGraphCaptureEnabled = IsGraphCaptureEnabledImpl;
IsGraphCaptured = IsGraphCapturedImpl;
ReplayGraph = ReplayGraphImpl;
GetGraphCaptureNodeAssignmentPolicy = GetGraphCaptureNodeAssignmentPolicyImpl;

const OrtApi& ort_api = factory_.GetOrtApi();
Ort::Status log_status(ort_api.Logger_LogMessage(&logger_, ORT_LOGGING_LEVEL_INFO,
"CUDA Plugin EP created",
Expand Down Expand Up @@ -304,5 +310,30 @@
EXCEPTION_TO_STATUS_END
}

/*static*/
bool ORT_API_CALL CudaEp::IsGraphCaptureEnabledImpl(const OrtEp* /*this_ptr*/) noexcept {
  // Stub for OrtEp::IsGraphCaptureEnabled: always reports that graph capture is disabled, so ORT
  // does not select this plugin EP for graph capture/replay.
  // TODO(adrianlizarraga): forward to EpImpl()->IsGraphCaptureEnabled()
  return false;
}

/*static*/
bool ORT_API_CALL CudaEp::IsGraphCapturedImpl(const OrtEp* /*this_ptr*/, int /*graph_annotation_id*/) noexcept {
  // Stub for OrtEp::IsGraphCaptured: no graph is ever reported as captured, regardless of the
  // requested graph annotation id.
  // TODO(adrianlizarraga): forward to EpImpl()->IsGraphCaptured(graph_annotation_id)
  return false;
}

/*static*/
OrtStatus* ORT_API_CALL CudaEp::ReplayGraphImpl(OrtEp* /*this_ptr*/, int /*graph_annotation_id*/) noexcept {
  // Stub for OrtEp::ReplayGraph: replay is not implemented yet, so an ORT_NOT_IMPLEMENTED status is
  // returned instead of silently succeeding.
  // TODO(adrianlizarraga): forward to EpImpl()->ReplayGraph(graph_annotation_id)
  return Ort::GetApi().CreateStatus(ORT_NOT_IMPLEMENTED,
                                    "Graph capture replay is not yet supported in the CUDA plugin EP.");
}

/*static*/
OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL CudaEp::GetGraphCaptureNodeAssignmentPolicyImpl(
    const OrtEp* /*this_ptr*/) noexcept {
  // Stub for OrtEp::GetGraphCaptureNodeAssignmentPolicy: hard-codes the policy that allows CPU nodes
  // for shape computation rather than querying the wrapped EP implementation.
  // TODO(adrianlizarraga): forward to EpImpl()->GetGraphCaptureNodeAssignmentPolicy()
  return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
}

} // namespace cuda_plugin
} // namespace onnxruntime
11 changes: 11 additions & 0 deletions onnxruntime/core/providers/cuda/plugin/cuda_ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ class CudaEp : public onnxruntime::ep::adapter::Ep {

static OrtStatus* ORT_API_CALL SyncImpl(OrtEp* this_ptr) noexcept;

// Graph capture/replay callbacks. These are assigned to the corresponding OrtEp function pointers
// (IsGraphCaptureEnabled, IsGraphCaptured, ReplayGraph, GetGraphCaptureNodeAssignmentPolicy) in cuda_ep.cc.

// Implements OrtEp::IsGraphCaptureEnabled.
static bool ORT_API_CALL IsGraphCaptureEnabledImpl(const OrtEp* this_ptr) noexcept;

// Implements OrtEp::IsGraphCaptured for the graph identified by graph_annotation_id.
static bool ORT_API_CALL IsGraphCapturedImpl(const OrtEp* this_ptr,
int graph_annotation_id) noexcept;

// Implements OrtEp::ReplayGraph for the graph identified by graph_annotation_id.
static OrtStatus* ORT_API_CALL ReplayGraphImpl(OrtEp* this_ptr,
int graph_annotation_id) noexcept;

// Implements OrtEp::GetGraphCaptureNodeAssignmentPolicy.
static OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL GetGraphCaptureNodeAssignmentPolicyImpl(
const OrtEp* this_ptr) noexcept;

CudaEpFactory& factory_;
std::string name_;
Config config_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,11 @@ namespace Dml
return m_impl->ReplayGraph(graph_annotation_id);
}

// Graph capture validation policy for this EP: compute nodes must be assigned to this EP, but CPU
// nodes used for shape computation are allowed as long as no memory copy nodes exist.
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override
{
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
}

private:
ComPtr<ExecutionProviderImpl> m_impl;
};
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/js/js_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class JsExecutionProvider : public IExecutionProvider {
bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured(int graph_annotation_id) const override;
Status ReplayGraph(int graph_annotation_id) override;
// Graph capture validation policy for the JS EP: compute nodes must be assigned to this EP, but CPU
// nodes used for shape computation are allowed as long as no memory copy nodes exist. This relaxes the
// IExecutionProvider default, which requires every node to be on the EP.
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override {
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
}

private:
bool IsGraphCaptureAllowed() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1280,7 +1280,9 @@ void NvExecutionProvider::HandleCudaGraphStart(cudaStream_t stream, bool require
}

bool NvExecutionProvider::IsGraphCaptureEnabled() const {
  // Return false so that ORT's framework does not cache this EP for ORT-managed graph capture/replay.
  // NvTensorRTRTX manages CUDA graph capture/replay internally, so the value of cuda_graph_enable_
  // is intentionally not reported here.
  return false;
}

bool NvExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
Expand Down
32 changes: 32 additions & 0 deletions onnxruntime/core/providers/webgpu/ep/ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ Ep::Ep(std::unique_ptr<IExecutionProvider> impl, Factory& factory, const OrtLogg
CreateSyncStreamForDevice = nullptr; // Not stream aware
GetCompiledModelCompatibilityInfo = nullptr; // Not a compiled EP
IsConcurrentRunSupported = IsConcurrentRunSupportedImpl;
IsGraphCaptureEnabled = IsGraphCaptureEnabledImpl;
IsGraphCaptured = IsGraphCapturedImpl;
ReplayGraph = ReplayGraphImpl;
GetGraphCaptureNodeAssignmentPolicy = GetGraphCaptureNodeAssignmentPolicyImpl;
}

// OrtEp interface implementations
Expand Down Expand Up @@ -253,6 +257,34 @@ OrtStatus* ORT_API_CALL Ep::IsConcurrentRunSupportedImpl(_In_ OrtEp* /*this_ptr*
return nullptr;
}

bool ORT_API_CALL Ep::IsGraphCaptureEnabledImpl(_In_ const OrtEp* this_ptr) noexcept {
  // OrtEp::IsGraphCaptureEnabled callback: recover the adapter object and ask the wrapped
  // execution provider whether graph capture is enabled.
  const Ep& self = *static_cast<const Ep*>(this_ptr);
  return self.EpImpl()->IsGraphCaptureEnabled();
}

bool ORT_API_CALL Ep::IsGraphCapturedImpl(_In_ const OrtEp* this_ptr, _In_ int graph_annotation_id) noexcept {
  // OrtEp::IsGraphCaptured callback: forward the query for the given graph annotation id to the
  // wrapped execution provider.
  const Ep& self = *static_cast<const Ep*>(this_ptr);
  return self.EpImpl()->IsGraphCaptured(graph_annotation_id);
}

OrtStatus* ORT_API_CALL Ep::ReplayGraphImpl(_In_ OrtEp* this_ptr, _In_ int graph_annotation_id) noexcept {
  // OrtEp::ReplayGraph callback: replay the captured graph on the wrapped execution provider and
  // translate its Status into the OrtStatus* convention (nullptr means success).
  EXCEPTION_TO_RETURNED_STATUS_BEGIN
  Ep& self = *static_cast<Ep*>(this_ptr);
  auto replay_status = self.EpImpl()->ReplayGraph(graph_annotation_id);
  if (replay_status.IsOK()) {
    return nullptr;
  }
  // Failure: carry both the error code and the message across the C API boundary.
  return Api().ort.CreateStatus(static_cast<OrtErrorCode>(replay_status.Code()),
                                replay_status.ErrorMessage().c_str());
  EXCEPTION_TO_RETURNED_STATUS_END
}

OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL Ep::GetGraphCaptureNodeAssignmentPolicyImpl(
    _In_ const OrtEp* this_ptr) noexcept {
  // OrtEp::GetGraphCaptureNodeAssignmentPolicy callback: report whatever policy the wrapped
  // execution provider declares.
  const Ep& self = *static_cast<const Ep*>(this_ptr);
  return self.EpImpl()->GetGraphCaptureNodeAssignmentPolicy();
}

OrtStatus* ORT_API_CALL Ep::CreateAllocatorImpl(_In_ OrtEp* this_ptr,
_In_ const OrtMemoryInfo* memory_info,
_Outptr_result_maybenull_ OrtAllocator** allocator) noexcept {
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/core/providers/webgpu/ep/ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ class Ep : public onnxruntime::ep::adapter::Ep {
static OrtStatus* ORT_API_CALL IsConcurrentRunSupportedImpl(_In_ OrtEp* this_ptr,
_Out_ bool* is_concurrent_run_supported) noexcept;

// Graph capture/replay callbacks for the OrtEp function table. Each one forwards to the wrapped
// execution provider obtained via EpImpl().

// Implements OrtEp::IsGraphCaptureEnabled.
static bool ORT_API_CALL IsGraphCaptureEnabledImpl(_In_ const OrtEp* this_ptr) noexcept;

// Implements OrtEp::IsGraphCaptured for the graph identified by graph_annotation_id.
static bool ORT_API_CALL IsGraphCapturedImpl(_In_ const OrtEp* this_ptr,
_In_ int graph_annotation_id) noexcept;

// Implements OrtEp::ReplayGraph; a non-OK Status from the wrapped EP is converted to an OrtStatus.
static OrtStatus* ORT_API_CALL ReplayGraphImpl(_In_ OrtEp* this_ptr,
_In_ int graph_annotation_id) noexcept;

// Implements OrtEp::GetGraphCaptureNodeAssignmentPolicy.
static OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL GetGraphCaptureNodeAssignmentPolicyImpl(
_In_ const OrtEp* this_ptr) noexcept;

Factory& factory_;
const OrtLogger& logger_;
Config config_{};
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ class WebGpuExecutionProvider : public IExecutionProvider {
bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured(int graph_annotation_id) const override;
Status ReplayGraph(int graph_annotation_id) override;
// Graph capture validation policy for the WebGPU EP: compute nodes must be assigned to this EP, but CPU
// nodes used for shape computation are allowed as long as no memory copy nodes exist. This relaxes the
// IExecutionProvider default, which requires every node to be on the EP.
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override {
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
}
webgpu::BufferManager& BufferManager() const;
AllocatorPtr PrepackAllocator() const { return prepack_allocator_; }
std::span<const std::string> GetForceCpuNodeNames() const { return force_cpu_node_names_; }
Expand Down
Loading
Loading