diff --git a/include/onnxruntime/core/framework/run_options.h b/include/onnxruntime/core/framework/run_options.h index e63ab044834f5..001fa158345ab 100644 --- a/include/onnxruntime/core/framework/run_options.h +++ b/include/onnxruntime/core/framework/run_options.h @@ -51,6 +51,11 @@ struct OrtRunOptions { onnxruntime::InlinedVector active_adapters; + // Optional sync stream for external resource import. + // When set, the EP uses this stream for execution, enabling proper + // synchronization with imported external semaphores. + OrtSyncStream* sync_stream = nullptr; + OrtRunOptions() = default; ~OrtRunOptions() = default; }; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2baa770af94af..9e14c21d8997a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7113,6 +7113,19 @@ struct OrtApi { * \since Version 1.24. */ ORT_API2_STATUS(EpAssignedNode_GetOperatorType, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out); + + /** \brief Sets OrtSyncStream for the run options + * + * OrtSyncStream is used to synchronize the execution of the model run for the device + * of the stream. It overrides the existing stream for the duration of the Run(). + * The stream instance must be alive for the duration of the Run() call. + * + * \param[in] options + * \param[in] sync_stream The synchronization stream. Pass nullptr to clear previous setting. + * + * \since 1.24 + */ + ORT_API_T(void, RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index aafabe7bc2cca..0efca9eaa928e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1340,6 +1340,15 @@ struct RunOptions : detail::Base { * \param adapter The LoraAdapter to be used as the active adapter */ RunOptions& AddActiveLoraAdapter(const LoraAdapter& adapter); + + /** \brief Associate a sync stream with the run options. + * + * When set, the EP uses this stream for execution, enabling proper + * synchronization with imported external semaphores. Wraps OrtApi::RunOptionsSetSyncStream. + * + * \param stream The OrtSyncStream to associate with these run options. May be nullptr to clear. + */ + RunOptions& SetSyncStream(OrtSyncStream* stream); }; namespace detail { diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 1d6d839c3a116..0249a2bd8e0c9 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1058,6 +1058,11 @@ inline RunOptions& RunOptions::AddActiveLoraAdapter(const LoraAdapter& adapter) return *this; } +inline RunOptions& RunOptions::SetSyncStream(OrtSyncStream* stream) { + GetApi().RunOptionsSetSyncStream(p_, stream); + return *this; +} + inline ModelCompilationOptions::ModelCompilationOptions(const Env& env, const SessionOptions& session_options) { ThrowOnError(GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(env, session_options, &this->p_)); } diff --git a/onnxruntime/core/framework/device_stream_collection.cc b/onnxruntime/core/framework/device_stream_collection.cc index a32973ddb8c9e..76da5702634aa 100644 --- a/onnxruntime/core/framework/device_stream_collection.cc +++ b/onnxruntime/core/framework/device_stream_collection.cc @@ -5,6 +5,8 @@ #include "core/framework/device_stream_collection.h" #include "core/framework/session_state.h" +#include + namespace onnxruntime { struct DummyNotification : public synchronize::Notification { @@ -50,7 +52,11 @@ class DeviceStreamCollectionImpl { Status CleanUp(bool sync_streams) { if (sync_streams) { - for (auto& device_stream : device_streams_) { + for (size_t i = 0, lim = device_streams_.size(); i < lim; ++i) { + Stream* device_stream = device_streams_[i]; + if (stream_override_ && i == stream_override_->first) { + device_stream = stream_override_->second; + } if (device_stream) { ORT_RETURN_IF_ERROR(device_stream->CleanUpOnRunEnd()); if (is_main_graph_) { @@ -76,11 +82,39 @@ class DeviceStreamCollectionImpl { void SetDeviceStream(size_t idx, Stream* stream) { ORT_ENFORCE(idx < num_streams_); + if (stream_override_) { + if (idx == stream_override_->first) { + ORT_THROW("Cannot set device stream for index ", idx, + " when there is an active stream override for the same index."); + } + } device_streams_[idx] = stream; } + Status SetStreamOverride(Stream* stream) { + ORT_ENFORCE(stream != nullptr); + for (size_t i = 0, lim = device_streams_.size(); i < lim; ++i) { + if (device_streams_[i] != nullptr && + // Exact match + device_streams_[i]->GetDevice() == stream->GetDevice()) { + stream_override_.emplace(i, stream); + return Status::OK(); + } + } + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "No matching stream found to override from OrtRunOptions"); + } + + void ResetStreamOverride() { + stream_override_.reset(); + } + Stream* GetStream(size_t stream_idx) const { ORT_ENFORCE(stream_idx < num_streams_); + if (stream_override_) { + if (stream_idx == stream_override_->first) { + return stream_override_->second; + } + } return device_streams_[stream_idx]; } @@ -94,6 +128,11 @@ class DeviceStreamCollectionImpl { size_t num_streams_; std::vector device_streams_; InlinedVector> owned_streams_; + // RunOptions allow specifying a stream override for a specific run. + // if this is present, it would be used as a stream for a given stream_id + // we declare it sepately as the original stream in device_streams_ should stay + // intact for future runs as we cache it in SessionState. + std::optional> stream_override_; const AllocatorMap& allocators_; bool is_main_graph_ = false; // This is used in ExecutionFrame when memory pattern is enabled, to allocate the peak size memory @@ -117,6 +156,14 @@ void DeviceStreamCollection::SetDeviceStream(size_t idx, Stream* stream) { impl_->SetDeviceStream(idx, stream); } +Status DeviceStreamCollection::SetStreamOverride(Stream* stream) { + return impl_->SetStreamOverride(stream); +} + +void DeviceStreamCollection::ResetStreamOverride() { + impl_->ResetStreamOverride(); +} + size_t DeviceStreamCollection::NumStreams() const { return impl_->NumStreams(); } @@ -140,6 +187,7 @@ DeviceStreamCollectionHolder::DeviceStreamCollectionHolder(const SessionState* s DeviceStreamCollectionHolder::~DeviceStreamCollectionHolder() { if (p_) { + p_->ResetStreamOverride(); session_state_->RecycleDeviceStreamCollection(std::move(p_)); } } diff --git a/onnxruntime/core/framework/device_stream_collection.h b/onnxruntime/core/framework/device_stream_collection.h index c76c7c731571c..34d2ecba13476 100644 --- a/onnxruntime/core/framework/device_stream_collection.h +++ b/onnxruntime/core/framework/device_stream_collection.h @@ -28,6 +28,15 @@ class DeviceStreamCollection { // a EP which doesn't support Stream, i.e. CPU based EPs. void SetDeviceStream(size_t stream_idx, Stream* stream); + // override the stream for matching device. + // only one override is allowed at a time presumably coming from + // OrtRunOptions + // returns an error if no matching stream + Status SetStreamOverride(Stream* stream); + + // Remove the override before caching/reusing the collection. + void ResetStreamOverride(); + // get the Stream instance on given stream index // The return value could be nullptr, which means the EP on this // logic sequence doesn't support Stream. diff --git a/onnxruntime/core/framework/run_options.cc b/onnxruntime/core/framework/run_options.cc index 0a2bb9507ac85..45635e973d09d 100644 --- a/onnxruntime/core/framework/run_options.cc +++ b/onnxruntime/core/framework/run_options.cc @@ -58,6 +58,10 @@ ORT_API_STATUS_IMPL(OrtApis::RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* op return nullptr; } +ORT_API(void, OrtApis::RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream) { + options->sync_stream = sync_stream; +} + ORT_API_STATUS_IMPL(OrtApis::AddRunConfigEntry, _Inout_ OrtRunOptions* options, _In_z_ const char* config_key, _In_z_ const char* config_value) { return onnxruntime::ToOrtStatus(options->config_options.AddConfigEntry(config_key, config_value)); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index a2966e2cc96f2..0c9b3c0663b5c 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -34,6 +34,7 @@ #include "core/framework/tensor_type_and_shape.h" #include "core/framework/op_kernel_context_internal.h" #include "core/framework/ort_value_pattern_planner.h" +#include "core/framework/plugin_ep_stream.h" #include "core/framework/transform_layout_functions.h" #include "core/framework/utils.h" #include "core/graph/graph_viewer.h" @@ -3093,6 +3094,15 @@ Status InferenceSession::Run(const RunOptions& run_options, #ifdef ORT_ENABLE_STREAM DeviceStreamCollectionHolder device_stream_collection_holder(session_state_.get()); + if (run_options.sync_stream != nullptr) { + if (session_options_.execution_mode != ExecutionMode::ORT_SEQUENTIAL) { + // XXX: Not tested in Parallel execution mode and disabled at this time. + LOGS(*session_logger_, WARNING) << "Setting sync stream is not supported in parallel execution mode."; + } else { + ORT_RETURN_IF_ERROR_SESSIONID_( + device_stream_collection_holder.p_->SetStreamOverride(run_options.sync_stream)); + } + } #endif if (retval.IsOK()) { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 09672306b1314..d4dbb58c96d7c 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4503,7 +4503,6 @@ static constexpr OrtApi ort_api_1_to_24 = { &OrtApis::DeviceEpIncompatibilityDetails_GetNotes, &OrtApis::DeviceEpIncompatibilityDetails_GetErrorCode, &OrtApis::ReleaseDeviceEpIncompatibilityDetails, - &OrtApis::CreateEnvWithOptions, &OrtApis::Session_GetEpGraphAssignmentInfo, &OrtApis::EpAssignedSubgraph_GetEpName, @@ -4511,6 +4510,7 @@ static constexpr OrtApi ort_api_1_to_24 = { &OrtApis::EpAssignedNode_GetName, &OrtApis::EpAssignedNode_GetDomain, &OrtApis::EpAssignedNode_GetOperatorType, + &OrtApis::RunOptionsSetSyncStream, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 2f2d7fa1dfcf4..bbe5f3db388b5 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -122,6 +122,7 @@ ORT_API_STATUS_IMPL(RunOptionsGetRunTag, _In_ const OrtRunOptions*, _Out_ const ORT_API_STATUS_IMPL(RunOptionsSetTerminate, _Inout_ OrtRunOptions* options); ORT_API_STATUS_IMPL(RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* options); +ORT_API(void, RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream); ORT_API_STATUS_IMPL(CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 07f2cc8581ed5..1c4e7800b7d2e 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -3089,5 +3089,70 @@ TEST(InferenceSessionTests, GraphResolveHandlesNodeWithSubgraphBeingRemoved) { ASSERT_STATUS_OK(session.Load(model_uri)); } +#ifdef ORT_ENABLE_STREAM +namespace { + +struct TestNotification : public synchronize::Notification { + explicit TestNotification(Stream& s) : Notification(s) {} + void Activate() override {} +}; + +struct TestOverrideStream : Stream { + TestOverrideStream(StreamHandle h, const OrtDevice& d) : Stream(h, d) {} + std::unique_ptr CreateNotification(size_t /*num_consumers*/) override { + return std::make_unique(*this); + } +}; +} // namespace + +TEST(DeviceStreamCollection, TestOverride) { + // We need an allocator map for the constructor, but it's not used in this test scenario. + AllocatorMap allocators; + DeviceStreamCollection collection(2, allocators, false); + + OrtDevice cpu_device(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0); + OrtDevice gpu_device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0); + + auto cpu_stream = std::make_unique(nullptr, cpu_device); + auto* cpu_stream_ptr = cpu_stream.get(); + collection.AddDeviceStream(0, std::move(cpu_stream)); + + auto gpu_stream = std::make_unique(nullptr, gpu_device); + auto* gpu_stream_ptr = gpu_stream.get(); + collection.AddDeviceStream(1, std::move(gpu_stream)); + + ASSERT_EQ(collection.GetStream(0), cpu_stream_ptr); + ASSERT_EQ(collection.GetStream(1), gpu_stream_ptr); + + // 1. Override CPU stream + TestOverrideStream cpu_override_stream(nullptr, cpu_device); + ASSERT_STATUS_OK(collection.SetStreamOverride(&cpu_override_stream)); + + // Verify override took effect for correct device match + ASSERT_EQ(collection.GetStream(0), &cpu_override_stream); + ASSERT_EQ(collection.GetStream(1), gpu_stream_ptr); + + // 2. Reset Override + collection.ResetStreamOverride(); + ASSERT_EQ(collection.GetStream(0), cpu_stream_ptr); + ASSERT_EQ(collection.GetStream(1), gpu_stream_ptr); + + // 3. Override GPU stream + TestOverrideStream gpu_override_stream(nullptr, gpu_device); + ASSERT_STATUS_OK(collection.SetStreamOverride(&gpu_override_stream)); + + ASSERT_EQ(collection.GetStream(0), cpu_stream_ptr); + ASSERT_EQ(collection.GetStream(1), &gpu_override_stream); + + collection.ResetStreamOverride(); + + // 4. Override with non-matching device + OrtDevice other_device(OrtDevice::FPGA, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0); + TestOverrideStream other_stream(nullptr, other_device); + ASSERT_FALSE(collection.SetStreamOverride(&other_stream).IsOK()); +} + +#endif // ORT_ENABLE_STREAM + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 35bf5349890e4..a96a2c48b4ca6 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -21,6 +22,7 @@ #include "core/common/common.h" #include "core/common/narrow.h" #include "core/graph/constants.h" +#include "core/framework/plugin_ep_stream.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_lite_custom_op.h" @@ -4828,6 +4830,97 @@ TEST(CApiTest, ModelWithExternalDataOutsideModelDirectoryShouldFailToLoad) { << "Exception message should indicate external data or security issue. Got: " << exception_message; } +#ifdef ORT_ENABLE_STREAM +#if USE_CUDA + +namespace { +struct TestCudaStreamOverrideUsed : onnxruntime::Stream { + TestCudaStreamOverrideUsed(onnxruntime::Stream* stream) + : onnxruntime::Stream(stream->GetHandle(), stream->GetDevice()), real_stream(stream) {} + + std::unique_ptr CreateNotification(size_t num_consumers) override { + return real_stream->CreateNotification(num_consumers); + } + + TestCudaStreamOverrideUsed(const TestCudaStreamOverrideUsed&) = delete; + TestCudaStreamOverrideUsed& operator=(const TestCudaStreamOverrideUsed&) = delete; + + void Flush() override { + flush_count++; + real_stream->Flush(); + } + + onnxruntime::Status CleanUpOnRunEnd() override { return real_stream->CleanUpOnRunEnd(); } + + onnxruntime::Stream* real_stream; + size_t flush_count{0}; +}; +} // namespace + +TEST(CApiTest, TestSyncStreamOverride) { +#ifdef _WIN32 + auto cuda_lib = ORT_TSTR("onnxruntime_providers_cuda.dll"); +#else + auto cuda_lib = ORT_TSTR("onnxruntime_providers_cuda.so"); +#endif + + if (!std::filesystem::exists(cuda_lib)) { + GTEST_SKIP() << "CUDA library was not found"; + } + + constexpr const char* cuda_ep_name = "ORT Cuda"; + ort_env->RegisterExecutionProviderLibrary(cuda_ep_name, cuda_lib); + auto ep_devices = ort_env->GetEpDevices(); + + Ort::ConstEpDevice cuda_device; + for (const auto& device : ep_devices) { + if (device.Device().Type() == OrtHardwareDeviceType_GPU && + device.Device().VendorId() == 0x10DE) { // NVIDIA vendor ID + cuda_device = device; + break; + } + } + + if (!cuda_device) { + GTEST_SKIP() << "No CUDA device found, skipping test."; + } + + // Create session with CUDA EP using C++ public API in Ort:: namespace + { + // Create a stream on CUDA Device + const auto sync_stream = cuda_device.CreateSyncStream(); + TestCudaStreamOverrideUsed cuda_override_stream(sync_stream); + + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_V2(*ort_env, {cuda_device}, Ort::KeyValuePairs{}); + + Ort::Session session(*ort_env, MODEL_URI, session_options); + + constexpr const std::array input_names = {"X"}; + constexpr const std::array output_names = {"Y"}; + constexpr const std::array input_shape = {3LL, 2LL}; + float x_value[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + auto input_value = Ort::Value::CreateTensor( + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), + x_value, std::size(x_value), input_shape.data(), input_shape.size()); + Ort::Value ort_inputs[] = {std::move(input_value)}; + + Ort::RunOptions run_options; + run_options.SetSyncStream(reinterpret_cast(&cuda_override_stream)); + + auto output_values = session.Run(run_options, + input_names.data(), ort_inputs, std::size(ort_inputs), + output_names.data(), output_names.size()); + + ASSERT_GT(cuda_override_stream.flush_count, 0U) + << "Expected the custom CUDA stream override to be used during session run."; + } + + ort_env->UnregisterExecutionProviderLibrary(cuda_ep_name); +} +#endif +#endif + #if !defined(ORT_MINIMAL_BUILD) TEST(CApiTest, GetEpGraphAssignmentInfo_NotEnabledError) { // Test that calling OrtApi::Session_GetEpGraphAssignmentInfo() without enabling the appropriate