Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/onnxruntime/core/framework/run_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ struct OrtRunOptions {

onnxruntime::InlinedVector<const onnxruntime::lora::LoraAdapter*> active_adapters;

// Optional sync stream for external resource import (not owned by OrtRunOptions).
// When set, the EP uses this stream for execution, enabling proper
// synchronization with imported external semaphores.
// The stream must remain alive for the duration of the Run() call.
OrtSyncStream* sync_stream = nullptr;

OrtRunOptions() = default;
~OrtRunOptions() = default;
};
Expand Down
13 changes: 13 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -7113,6 +7113,19 @@ struct OrtApi {
* \since Version 1.24.
*/
ORT_API2_STATUS(EpAssignedNode_GetOperatorType, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);

/** \brief Sets OrtSyncStream for the run options
*
* OrtSyncStream is used to synchronize the execution of the model run for the device
* of the stream. It overrides the existing stream for the duration of the Run().
* The stream instance must be alive for the duration of the Run() call.
*
* \param[in] options The run options to modify.
* \param[in] sync_stream The synchronization stream. Pass nullptr to clear previous setting.
*
* \since Version 1.24.
*/
ORT_API_T(void, RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream);
};

/*
Expand Down
9 changes: 9 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1340,6 +1340,15 @@ struct RunOptions : detail::Base<OrtRunOptions> {
* \param adapter The LoraAdapter to be used as the active adapter
*/
RunOptions& AddActiveLoraAdapter(const LoraAdapter& adapter);

/** \brief Associate a sync stream with the run options.
*
* When set, the EP uses this stream for execution, enabling proper
* synchronization with imported external semaphores. Wraps OrtApi::RunOptionsSetSyncStream.
*
* \param stream The OrtSyncStream to associate with these run options. May be nullptr to clear.
*/
RunOptions& SetSyncStream(OrtSyncStream* stream);
};

namespace detail {
Expand Down
5 changes: 5 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,11 @@ inline RunOptions& RunOptions::AddActiveLoraAdapter(const LoraAdapter& adapter)
return *this;
}

// Stores `stream` on the wrapped OrtRunOptions via OrtApi::RunOptionsSetSyncStream.
// `stream` may be nullptr to clear a previously set stream.
// The C API call returns void, so there is no status to check here.
// Returns *this so that option setters can be chained.
inline RunOptions& RunOptions::SetSyncStream(OrtSyncStream* stream) {
GetApi().RunOptionsSetSyncStream(p_, stream);
return *this;
}

// Builds model compilation options from an environment and existing session options.
// The handle produced by the Compile API is written into this->p_; the returned
// status is passed to ThrowOnError (presumably throws on failure - confirm in its definition).
inline ModelCompilationOptions::ModelCompilationOptions(const Env& env, const SessionOptions& session_options) {
ThrowOnError(GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(env, session_options, &this->p_));
}
Expand Down
50 changes: 49 additions & 1 deletion onnxruntime/core/framework/device_stream_collection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "core/framework/device_stream_collection.h"
#include "core/framework/session_state.h"

#include <optional>

Check warning on line 8 in onnxruntime/core/framework/device_stream_collection.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: device_stream_collection.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/framework/device_stream_collection.cc:8: Found C++ system header after other header. Should be: device_stream_collection.h, c system, c++ system, other. [build/include_order] [4]

namespace onnxruntime {

struct DummyNotification : public synchronize::Notification {
Expand Down Expand Up @@ -50,7 +52,11 @@

Status CleanUp(bool sync_streams) {
if (sync_streams) {
for (auto& device_stream : device_streams_) {
for (size_t i = 0, lim = device_streams_.size(); i < lim; ++i) {
Stream* device_stream = device_streams_[i];
if (stream_override_ && i == stream_override_->first) {
device_stream = stream_override_->second;
}
if (device_stream) {
ORT_RETURN_IF_ERROR(device_stream->CleanUpOnRunEnd());
if (is_main_graph_) {
Expand All @@ -76,11 +82,39 @@

// Stores the raw `stream` pointer in slot `idx` of device_streams_.
// Enforces that `idx` is within num_streams_, and throws if a stream override
// is currently active for that same slot.
void SetDeviceStream(size_t idx, Stream* stream) {
  ORT_ENFORCE(idx < num_streams_);
  const bool slot_is_overridden = stream_override_.has_value() && stream_override_->first == idx;
  if (slot_is_overridden) {
    ORT_THROW("Cannot set device stream for index ", idx,
              " when there is an active stream override for the same index.");
  }
  device_streams_[idx] = stream;
}

// Records `stream` as the override for the first populated slot whose stream
// reports a device equal to the device of `stream`.
// Enforces that `stream` is non-null. Returns OK once an override is recorded,
// or INVALID_ARGUMENT when no populated slot has a matching device.
Status SetStreamOverride(Stream* stream) {
  ORT_ENFORCE(stream != nullptr);
  const size_t slot_count = device_streams_.size();
  for (size_t slot = 0; slot < slot_count; ++slot) {
    Stream* const current = device_streams_[slot];
    // Skip empty slots; otherwise require an exact device match.
    if (current == nullptr || !(current->GetDevice() == stream->GetDevice())) {
      continue;
    }
    stream_override_.emplace(slot, stream);
    return Status::OK();
  }
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "No matching stream found to override from OrtRunOptions");
}

// Clears any active stream override. Entries in device_streams_ are not touched,
// so GetStream() goes back to returning the stored stream for every slot.
void ResetStreamOverride() {
stream_override_.reset();
}

// Returns the stream for `stream_idx`: the override stream when an override is
// active for that slot, otherwise the entry stored in device_streams_.
// Enforces that `stream_idx` is within num_streams_.
Stream* GetStream(size_t stream_idx) const {
  ORT_ENFORCE(stream_idx < num_streams_);
  const bool use_override = stream_override_.has_value() && stream_override_->first == stream_idx;
  return use_override ? stream_override_->second : device_streams_[stream_idx];
}

Expand All @@ -94,6 +128,11 @@
size_t num_streams_;
std::vector<Stream*> device_streams_;
InlinedVector<std::unique_ptr<Stream>> owned_streams_;
// RunOptions allow specifying a stream override for a specific run.
// If present, it is used in place of the stream at the given stream index.
// We declare it separately because the original stream in device_streams_ should stay
// intact for future runs, as we cache it in SessionState.
std::optional<std::pair<size_t, Stream*>> stream_override_;
const AllocatorMap& allocators_;
bool is_main_graph_ = false;
// This is used in ExecutionFrame when memory pattern is enabled, to allocate the peak size memory
Expand All @@ -117,6 +156,14 @@
impl_->SetDeviceStream(idx, stream);
}

// Forwards to the implementation object; returns its status unchanged.
Status DeviceStreamCollection::SetStreamOverride(Stream* stream) {
return impl_->SetStreamOverride(stream);
}

// Forwards to the implementation object to drop any active stream override.
void DeviceStreamCollection::ResetStreamOverride() {
impl_->ResetStreamOverride();
}

// Returns the stream count reported by the implementation object.
size_t DeviceStreamCollection::NumStreams() const {
return impl_->NumStreams();
}
Expand All @@ -140,6 +187,7 @@

// Hands the held collection back to the session state, if there is one.
// The stream override is cleared first so it is removed before the collection
// is recycled.
DeviceStreamCollectionHolder::~DeviceStreamCollectionHolder() {
  if (!p_) {
    return;
  }
  p_->ResetStreamOverride();
  session_state_->RecycleDeviceStreamCollection(std::move(p_));
}
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/core/framework/device_stream_collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ class DeviceStreamCollection {
// a EP which doesn't support Stream, i.e. CPU based EPs.
void SetDeviceStream(size_t stream_idx, Stream* stream);

// Override the stream for the matching device.
// Only one override is allowed at a time, presumably coming from
// OrtRunOptions.
// Returns an error if there is no matching stream.
Status SetStreamOverride(Stream* stream);

// Remove the override before caching/reusing the collection.
void ResetStreamOverride();

// get the Stream instance on given stream index
// The return value could be nullptr, which means the EP on this
// logic sequence doesn't support Stream.
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/framework/run_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ ORT_API_STATUS_IMPL(OrtApis::RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* op
return nullptr;
}

// Stores the caller-provided sync stream pointer on the run options.
// The pointer is stored as-is; passing nullptr writes nullptr, which is the
// default value of OrtRunOptions::sync_stream.
// NOTE(review): `options` is dereferenced without a null check - confirm callers guarantee non-null.
ORT_API(void, OrtApis::RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream) {
options->sync_stream = sync_stream;
}

ORT_API_STATUS_IMPL(OrtApis::AddRunConfigEntry, _Inout_ OrtRunOptions* options,
_In_z_ const char* config_key, _In_z_ const char* config_value) {
return onnxruntime::ToOrtStatus(options->config_options.AddConfigEntry(config_key, config_value));
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "core/framework/tensor_type_and_shape.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/framework/ort_value_pattern_planner.h"
#include "core/framework/plugin_ep_stream.h"
#include "core/framework/transform_layout_functions.h"
#include "core/framework/utils.h"
#include "core/graph/graph_viewer.h"
Expand Down Expand Up @@ -3093,6 +3094,15 @@ Status InferenceSession::Run(const RunOptions& run_options,

#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder device_stream_collection_holder(session_state_.get());
if (run_options.sync_stream != nullptr) {
if (session_options_.execution_mode != ExecutionMode::ORT_SEQUENTIAL) {
// XXX: Not tested in Parallel execution mode and disabled at this time.
LOGS(*session_logger_, WARNING) << "Setting sync stream is not supported in parallel execution mode.";
} else {
ORT_RETURN_IF_ERROR_SESSIONID_(
device_stream_collection_holder.p_->SetStreamOverride(run_options.sync_stream));
}
}
#endif

if (retval.IsOK()) {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4503,14 +4503,14 @@ static constexpr OrtApi ort_api_1_to_24 = {
&OrtApis::DeviceEpIncompatibilityDetails_GetNotes,
&OrtApis::DeviceEpIncompatibilityDetails_GetErrorCode,
&OrtApis::ReleaseDeviceEpIncompatibilityDetails,

&OrtApis::CreateEnvWithOptions,
&OrtApis::Session_GetEpGraphAssignmentInfo,
&OrtApis::EpAssignedSubgraph_GetEpName,
&OrtApis::EpAssignedSubgraph_GetNodes,
&OrtApis::EpAssignedNode_GetName,
&OrtApis::EpAssignedNode_GetDomain,
&OrtApis::EpAssignedNode_GetOperatorType,
&OrtApis::RunOptionsSetSyncStream,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ ORT_API_STATUS_IMPL(RunOptionsGetRunTag, _In_ const OrtRunOptions*, _Out_ const

ORT_API_STATUS_IMPL(RunOptionsSetTerminate, _Inout_ OrtRunOptions* options);
ORT_API_STATUS_IMPL(RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* options);
ORT_API(void, RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream);

ORT_API_STATUS_IMPL(CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator,
_In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type,
Expand Down
65 changes: 65 additions & 0 deletions onnxruntime/test/framework/inference_session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3089,5 +3089,70 @@ TEST(InferenceSessionTests, GraphResolveHandlesNodeWithSubgraphBeingRemoved) {
ASSERT_STATUS_OK(session.Load(model_uri));
}

#ifdef ORT_ENABLE_STREAM
// Test-only helpers for exercising DeviceStreamCollection stream overrides.
namespace {

// Notification whose Activate() does nothing; it exists so that
// TestOverrideStream::CreateNotification has a concrete type to return.
struct TestNotification : public synchronize::Notification {
explicit TestNotification(Stream& s) : Notification(s) {}
void Activate() override {}
};

// Minimal Stream subclass built from a handle and a device; it only
// overrides CreateNotification.
struct TestOverrideStream : Stream {
TestOverrideStream(StreamHandle h, const OrtDevice& d) : Stream(h, d) {}
// The requested consumer count is ignored; a fresh TestNotification bound to this stream is returned.
std::unique_ptr<synchronize::Notification> CreateNotification(size_t /*num_consumers*/) override {
return std::make_unique<TestNotification>(*this);
}
};
} // namespace

// Checks that a stream override replaces only the slot whose device matches,
// that resetting restores the original streams, and that an override for a
// device with no slot in the collection is rejected.
TEST(DeviceStreamCollection, TestOverride) {
  // The constructor requires an allocator map; this scenario does not use it.
  AllocatorMap allocators;
  DeviceStreamCollection collection(2, allocators, false);

  OrtDevice cpu_device(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0);
  OrtDevice gpu_device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0);

  // Slot 0 holds a CPU stream; keep a raw pointer to compare against later.
  auto owned_cpu_stream = std::make_unique<TestOverrideStream>(nullptr, cpu_device);
  auto* original_cpu = owned_cpu_stream.get();
  collection.AddDeviceStream(0, std::move(owned_cpu_stream));

  // Slot 1 holds a GPU stream.
  auto owned_gpu_stream = std::make_unique<TestOverrideStream>(nullptr, gpu_device);
  auto* original_gpu = owned_gpu_stream.get();
  collection.AddDeviceStream(1, std::move(owned_gpu_stream));

  // Baseline: no override is active.
  ASSERT_EQ(collection.GetStream(0), original_cpu);
  ASSERT_EQ(collection.GetStream(1), original_gpu);

  // Step 1: overriding with a CPU stream replaces slot 0 only.
  TestOverrideStream cpu_replacement(nullptr, cpu_device);
  ASSERT_STATUS_OK(collection.SetStreamOverride(&cpu_replacement));

  ASSERT_EQ(collection.GetStream(0), &cpu_replacement);
  ASSERT_EQ(collection.GetStream(1), original_gpu);

  // Step 2: resetting brings back both original streams.
  collection.ResetStreamOverride();
  ASSERT_EQ(collection.GetStream(0), original_cpu);
  ASSERT_EQ(collection.GetStream(1), original_gpu);

  // Step 3: overriding with a GPU stream replaces slot 1 only.
  TestOverrideStream gpu_replacement(nullptr, gpu_device);
  ASSERT_STATUS_OK(collection.SetStreamOverride(&gpu_replacement));

  ASSERT_EQ(collection.GetStream(0), original_cpu);
  ASSERT_EQ(collection.GetStream(1), &gpu_replacement);

  collection.ResetStreamOverride();

  // Step 4: a device with no slot in the collection must be rejected.
  OrtDevice unmatched_device(OrtDevice::FPGA, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0);
  TestOverrideStream unmatched_stream(nullptr, unmatched_device);
  ASSERT_FALSE(collection.SetStreamOverride(&unmatched_stream).IsOK());
}

#endif // ORT_ENABLE_STREAM

} // namespace test
} // namespace onnxruntime
Loading
Loading