Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 102 additions & 10 deletions onnxruntime/core/platform/windows/telemetry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

#include "core/platform/windows/telemetry.h"
#include <mutex>
#include <string>
#include <vector>
#include <cwchar>
#include <winsvc.h>

Check warning on line 9 in onnxruntime/core/platform/windows/telemetry.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C system header after C++ system header. Should be: telemetry.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/platform/windows/telemetry.cc:9: Found C system header after C++ system header. Should be: telemetry.h, c system, c++ system, other. [build/include_order] [4]
#include "core/common/logging/logging.h"
#include "onnxruntime_config.h"

Expand Down Expand Up @@ -51,6 +55,80 @@
// {3a26b1ff-7484-7484-7484-15261f42614d}
(0x3a26b1ff, 0x7484, 0x7484, 0x74, 0x84, 0x15, 0x26, 0x1f, 0x42, 0x61, 0x4d),
TraceLoggingOptionMicrosoftTelemetry());

std::string ConvertWideStringToUtf8(const std::wstring& wide) {
if (wide.empty())
return {};

const UINT code_page = CP_UTF8;
const DWORD flags = 0;
LPCWCH const src = wide.data();
const int src_len = static_cast<int>(wide.size());
int utf8_length = ::WideCharToMultiByte(code_page, flags, src, src_len, nullptr, 0, nullptr, nullptr);
if (utf8_length == 0)
return {};

std::string utf8(utf8_length, '\0');
if (::WideCharToMultiByte(code_page, flags, src, src_len, utf8.data(), utf8_length, nullptr, nullptr) == 0)
return {};

return utf8;
}

std::string GetServiceNamesForCurrentProcess() {
static std::once_flag once_flag;
static std::string service_names;

std::call_once(once_flag, [] {
SC_HANDLE service_manager = ::OpenSCManagerW(nullptr, nullptr, SC_MANAGER_ENUMERATE_SERVICE);
if (service_manager == nullptr)
return;

DWORD bytes_needed = 0;
DWORD services_returned = 0;
DWORD resume_handle = 0;
if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_ACTIVE, nullptr, 0, &bytes_needed,
&services_returned, &resume_handle, nullptr) &&
::GetLastError() != ERROR_MORE_DATA) {
::CloseServiceHandle(service_manager);
return;
}

if (bytes_needed == 0) {
::CloseServiceHandle(service_manager);
return;
}

std::vector<uint8_t> buffer(bytes_needed);
auto* services = reinterpret_cast<ENUM_SERVICE_STATUS_PROCESSW*>(buffer.data());
services_returned = 0;
resume_handle = 0;
if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_ACTIVE, reinterpret_cast<LPBYTE>(services),
bytes_needed, &bytes_needed, &services_returned, &resume_handle, nullptr)) {
::CloseServiceHandle(service_manager);
return;
}

DWORD current_pid = ::GetCurrentProcessId();
std::wstring aggregated;
bool first = true;
for (DWORD i = 0; i < services_returned; ++i) {
if (services[i].ServiceStatusProcess.dwProcessId == current_pid) {
if (!first) {
aggregated.push_back(L',');
}
aggregated.append(services[i].lpServiceName);
first = false;
}
}

::CloseServiceHandle(service_manager);

service_names = ConvertWideStringToUtf8(aggregated);
});

return service_names;
}
} // namespace

#ifdef _MSC_VER
Expand Down Expand Up @@ -178,6 +256,7 @@
#if BUILD_INBOX
isRedist = false;
#endif
const std::string service_names = GetServiceNamesForCurrentProcess();
TraceLoggingWrite(telemetry_provider_handle,
"ProcessInfo",
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
Expand All @@ -189,7 +268,8 @@
TraceLoggingString(ORT_VERSION, "runtimeVersion"),
TraceLoggingBool(IsDebuggerPresent(), "isDebuggerAttached"),
TraceLoggingBool(isRedist, "isRedist"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"),
TraceLoggingString(service_names.c_str(), "serviceNames"));

process_info_logged = true;
}
Expand All @@ -204,7 +284,8 @@
TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingLevel(WINEVENT_LEVEL_INFO));
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
}

void WindowsTelemetry::LogEvaluationStop(uint32_t session_id) const {
Expand Down Expand Up @@ -278,6 +359,7 @@
execution_provider_string += i;
}

const std::string service_names = GetServiceNamesForCurrentProcess();
// Difference is MeasureEvent & isCaptureState, but keep in sync otherwise
if (!captureState) {
TraceLoggingWrite(telemetry_provider_handle,
Expand All @@ -304,7 +386,9 @@
TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"),
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
TraceLoggingString(loaded_from.c_str(), "loadedFrom"),
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"),
TraceLoggingString(service_names.c_str(), "serviceNames"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
} else {
TraceLoggingWrite(telemetry_provider_handle,
"SessionCreation_CaptureState",
Expand All @@ -330,7 +414,9 @@
TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"),
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
TraceLoggingString(loaded_from.c_str(), "loadedFrom"),
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"),
TraceLoggingString(service_names.c_str(), "serviceNames"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
}
}

Expand All @@ -356,7 +442,8 @@
TraceLoggingString(status.ErrorMessage().c_str(), "errorMessage"),
TraceLoggingString(file, "file"),
TraceLoggingString(function, "function"),
TraceLoggingInt32(line, "line"));
TraceLoggingInt32(line, "line"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
#else
TraceLoggingWrite(telemetry_provider_handle,
"RuntimeError",
Expand All @@ -372,7 +459,8 @@
TraceLoggingString(status.ErrorMessage().c_str(), "errorMessage"),
TraceLoggingString(file, "file"),
TraceLoggingString(function, "function"),
TraceLoggingInt32(line, "line"));
TraceLoggingInt32(line, "line"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
#endif
}

Expand Down Expand Up @@ -402,7 +490,8 @@
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingUInt32(total_runs_since_last, "totalRuns"),
TraceLoggingInt64(total_run_duration_since_last, "totalRunDuration"),
TraceLoggingString(total_duration_per_batch_size.c_str(), "totalRunDurationPerBatchSize"));
TraceLoggingString(total_duration_per_batch_size.c_str(), "totalRunDurationPerBatchSize"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
}

void WindowsTelemetry::LogExecutionProviderEvent(LUID* adapterLuid) const {
Expand Down Expand Up @@ -478,7 +567,8 @@
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingString(selection_policy.c_str(), "selectionPolicy"),
TraceLoggingString(requested_execution_provider_string.c_str(), "requestedExecutionProviderIds"),
TraceLoggingString(available_execution_provider_string.c_str(), "availableExecutionProviderIds"));
TraceLoggingString(available_execution_provider_string.c_str(), "availableExecutionProviderIds"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
}

void WindowsTelemetry::LogProviderOptions(const std::string& provider_id, const std::string& provider_options_string, bool captureState) const {
Expand All @@ -497,7 +587,8 @@
// Telemetry info
TraceLoggingUInt8(0, "schemaVersion"),
TraceLoggingString(provider_id.c_str(), "providerId"),
TraceLoggingString(provider_options_string.c_str(), "providerOptions"));
TraceLoggingString(provider_options_string.c_str(), "providerOptions"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
} else {
TraceLoggingWrite(telemetry_provider_handle,
"ProviderOptions_CaptureState",
Expand All @@ -509,7 +600,8 @@
// Telemetry info
TraceLoggingUInt8(0, "schemaVersion"),
TraceLoggingString(provider_id.c_str(), "providerId"),
TraceLoggingString(provider_options_string.c_str(), "providerOptions"));
TraceLoggingString(provider_options_string.c_str(), "providerOptions"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
}
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4257,7 +4257,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz
static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change");

// So that nobody forgets to finish an API version, this check will serve as a reminder:
static_assert(std::string_view(ORT_VERSION) == "1.23.3",
static_assert(std::string_view(ORT_VERSION) == "1.23.4",
"ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly");
// 1. Update the hardcoded version string in above static_assert to silence it
// 2. If there were any APIs added to ort_api_1_to_23 above:
Expand Down
Loading