diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 2e5d334856278..029b17eb3502e 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -3,6 +3,10 @@ #include "core/platform/windows/telemetry.h" #include +#include +#include +#include +#include #include "core/common/logging/logging.h" #include "onnxruntime_config.h" @@ -51,6 +55,80 @@ TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntim // {3a26b1ff-7484-7484-7484-15261f42614d} (0x3a26b1ff, 0x7484, 0x7484, 0x74, 0x84, 0x15, 0x26, 0x1f, 0x42, 0x61, 0x4d), TraceLoggingOptionMicrosoftTelemetry()); + +std::string ConvertWideStringToUtf8(const std::wstring& wide) { + if (wide.empty()) + return {}; + + const UINT code_page = CP_UTF8; + const DWORD flags = 0; + LPCWCH const src = wide.data(); + const int src_len = static_cast(wide.size()); + int utf8_length = ::WideCharToMultiByte(code_page, flags, src, src_len, nullptr, 0, nullptr, nullptr); + if (utf8_length == 0) + return {}; + + std::string utf8(utf8_length, '\0'); + if (::WideCharToMultiByte(code_page, flags, src, src_len, utf8.data(), utf8_length, nullptr, nullptr) == 0) + return {}; + + return utf8; +} + +std::string GetServiceNamesForCurrentProcess() { + static std::once_flag once_flag; + static std::string service_names; + + std::call_once(once_flag, [] { + SC_HANDLE service_manager = ::OpenSCManagerW(nullptr, nullptr, SC_MANAGER_ENUMERATE_SERVICE); + if (service_manager == nullptr) + return; + + DWORD bytes_needed = 0; + DWORD services_returned = 0; + DWORD resume_handle = 0; + if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_ACTIVE, nullptr, 0, &bytes_needed, + &services_returned, &resume_handle, nullptr) && + ::GetLastError() != ERROR_MORE_DATA) { + ::CloseServiceHandle(service_manager); + return; + } + + if (bytes_needed == 0) { + ::CloseServiceHandle(service_manager); + return; + } + + std::vector buffer(bytes_needed); + auto* services = reinterpret_cast(buffer.data()); + services_returned = 0; + resume_handle = 0; + if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_ACTIVE, reinterpret_cast(services), + bytes_needed, &bytes_needed, &services_returned, &resume_handle, nullptr)) { + ::CloseServiceHandle(service_manager); + return; + } + + DWORD current_pid = ::GetCurrentProcessId(); + std::wstring aggregated; + bool first = true; + for (DWORD i = 0; i < services_returned; ++i) { + if (services[i].ServiceStatusProcess.dwProcessId == current_pid) { + if (!first) { + aggregated.push_back(L','); + } + aggregated.append(services[i].lpServiceName); + first = false; + } + } + + ::CloseServiceHandle(service_manager); + + service_names = ConvertWideStringToUtf8(aggregated); + }); + + return service_names; +} } // namespace #ifdef _MSC_VER @@ -178,6 +256,7 @@ void WindowsTelemetry::LogProcessInfo() const { #if BUILD_INBOX isRedist = false; #endif + const std::string service_names = GetServiceNamesForCurrentProcess(); TraceLoggingWrite(telemetry_provider_handle, "ProcessInfo", TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), @@ -189,7 +268,8 @@ void WindowsTelemetry::LogProcessInfo() const { TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingBool(IsDebuggerPresent(), "isDebuggerAttached"), TraceLoggingBool(isRedist, "isRedist"), - TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"), + TraceLoggingString(service_names.c_str(), "serviceNames")); process_info_logged = true; } @@ -204,7 +284,8 @@ void WindowsTelemetry::LogSessionCreationStart(uint32_t session_id) const { TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingUInt32(session_id, "sessionId"), - TraceLoggingLevel(WINEVENT_LEVEL_INFO)); + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } void WindowsTelemetry::LogEvaluationStop(uint32_t session_id) const { @@ -278,6 +359,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio execution_provider_string += i; } + const std::string service_names = GetServiceNamesForCurrentProcess(); // Difference is MeasureEvent & isCaptureState, but keep in sync otherwise if (!captureState) { TraceLoggingWrite(telemetry_provider_handle, @@ -304,7 +386,9 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"), TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"), TraceLoggingString(loaded_from.c_str(), "loadedFrom"), - TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds")); + TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), + TraceLoggingString(service_names.c_str(), "serviceNames"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } else { TraceLoggingWrite(telemetry_provider_handle, "SessionCreation_CaptureState", @@ -330,7 +414,9 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"), TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"), TraceLoggingString(loaded_from.c_str(), "loadedFrom"), - TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds")); + TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), + TraceLoggingString(service_names.c_str(), "serviceNames"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } } @@ -356,7 +442,8 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingString(status.ErrorMessage().c_str(), "errorMessage"), TraceLoggingString(file, "file"), TraceLoggingString(function, "function"), - TraceLoggingInt32(line, "line")); + TraceLoggingInt32(line, "line"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); #else TraceLoggingWrite(telemetry_provider_handle, "RuntimeError", @@ -372,7 +459,8 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingString(status.ErrorMessage().c_str(), "errorMessage"), TraceLoggingString(file, "file"), TraceLoggingString(function, "function"), - TraceLoggingInt32(line, "line")); + TraceLoggingInt32(line, "line"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); #endif } @@ -402,7 +490,8 @@ void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_s TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingUInt32(total_runs_since_last, "totalRuns"), TraceLoggingInt64(total_run_duration_since_last, "totalRunDuration"), - TraceLoggingString(total_duration_per_batch_size.c_str(), "totalRunDurationPerBatchSize")); + TraceLoggingString(total_duration_per_batch_size.c_str(), "totalRunDurationPerBatchSize"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } void WindowsTelemetry::LogExecutionProviderEvent(LUID* adapterLuid) const { @@ -478,7 +567,8 @@ void WindowsTelemetry::LogAutoEpSelection(uint32_t session_id, const std::string TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingString(selection_policy.c_str(), "selectionPolicy"), TraceLoggingString(requested_execution_provider_string.c_str(), "requestedExecutionProviderIds"), - TraceLoggingString(available_execution_provider_string.c_str(), "availableExecutionProviderIds")); + TraceLoggingString(available_execution_provider_string.c_str(), "availableExecutionProviderIds"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } void WindowsTelemetry::LogProviderOptions(const std::string& provider_id, const std::string& provider_options_string, bool captureState) const { @@ -497,7 +587,8 @@ void WindowsTelemetry::LogProviderOptions(const std::string& provider_id, const // Telemetry info TraceLoggingUInt8(0, "schemaVersion"), TraceLoggingString(provider_id.c_str(), "providerId"), - TraceLoggingString(provider_options_string.c_str(), "providerOptions")); + TraceLoggingString(provider_options_string.c_str(), "providerOptions"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } else { TraceLoggingWrite(telemetry_provider_handle, "ProviderOptions_CaptureState", @@ -509,7 +600,8 @@ void WindowsTelemetry::LogProviderOptions(const std::string& provider_id, const // Telemetry info TraceLoggingUInt8(0, "schemaVersion"), TraceLoggingString(provider_id.c_str(), "providerId"), - TraceLoggingString(provider_options_string.c_str(), "providerOptions")); + TraceLoggingString(provider_options_string.c_str(), "providerOptions"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index f79ebde61facc..f0e547f08d668 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4257,7 +4257,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: -static_assert(std::string_view(ORT_VERSION) == "1.23.3", +static_assert(std::string_view(ORT_VERSION) == "1.23.4", "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly"); // 1. Update the hardcoded version string in above static_assert to silence it // 2. If there were any APIs added to ort_api_1_to_23 above: