diff --git a/CMakeLists.txt b/CMakeLists.txt index bb8bc14288..0fce41fdfe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -80,6 +80,11 @@ if(ENABLE_TESTS) endif() endif() +if(ENABLE_TRACING) + message(STATUS "Tracing is enabled.") + add_compile_definitions(ORTGENAI_ENABLE_TRACING) +endif() + find_package(Threads REQUIRED) if(WIN32) diff --git a/cmake/options.cmake b/cmake/options.cmake index d4ddbce8b4..67286bfe41 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -17,3 +17,6 @@ option(TEST_PHI2 "Enable tests for Phi2" OFF) # performance option(ENABLE_MODEL_BENCHMARK "Build model benchmark program" ON) + +# diagnostics +option(ENABLE_TRACING "Enable recording of tracing data" OFF) diff --git a/src/generators.cpp b/src/generators.cpp index f7d2bdb103..8754c35a4a 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -8,6 +8,7 @@ #include "models/decoder_only.h" #include "constrained_logits_processor.h" #include "search.h" +#include "tracing.h" #include "cpu/interface.h" #include "cuda/interface.h" #include "dml/interface.h" @@ -47,7 +48,7 @@ static bool _ = (Ort::InitApi(), false); static OrtLoggingLevel GetDefaultOrtLoggingLevel() { bool ort_verbose_logging = false; - GetEnvironmentVariable("ORTGENAI_ORT_VERBOSE_LOGGING", ort_verbose_logging); + GetEnv("ORTGENAI_ORT_VERBOSE_LOGGING", ort_verbose_logging); return ort_verbose_logging ? OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE : OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR; } @@ -347,6 +348,8 @@ void Generator::AuxAppendTokens(cpu_span input_ids) { } void Generator::AppendTokens(cpu_span input_ids) { + DurationTrace trace{"Generator::AppendTokens"}; + ThrowErrorIfSessionTerminated(state_->session_terminated_); if (input_ids.size() == 0) throw std::runtime_error("input_ids is empty"); @@ -434,6 +437,8 @@ void Generator::SetLogits(DeviceSpan logits) { } void Generator::GenerateNextToken() { + DurationTrace trace{"Generator::GenerateNextToken"}; + ThrowErrorIfSessionTerminated(state_->session_terminated_); if (search_->GetSequenceLength() == 0 && !computed_logits_) throw std::runtime_error("GenerateNextToken called with no prior state. Please call AppendTokens, SetLogits, or params.SetInputs before calling GenerateNextToken."); diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp index 602596bda6..2bcbbdea7c 100644 --- a/src/models/decoder_only_pipeline.cpp +++ b/src/models/decoder_only_pipeline.cpp @@ -3,6 +3,7 @@ #include "../generators.h" #include "../logging.h" +#include "../tracing.h" #include "decoder_only_pipeline.h" #include "windowed_kv_cache.h" @@ -165,6 +166,8 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan continue; } + DurationTrace trace{MakeString("DecoderOnlyPipelineState::RunPipeline[", pipeline_state->id_, "]")}; + if (model_.config_->model.decoder.pipeline[pipeline_state->id_].reset_session_idx > -1) { if (model_.config_->model.decoder.pipeline[pipeline_state->id_].reset_session_idx >= static_cast(model_.sessions_.size())) { @@ -293,6 +296,8 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan DeviceSpan DecoderOnlyPipelineState::Run(int total_length, DeviceSpan& next_tokens, DeviceSpan next_indices) { + DurationTrace trace{"DecoderOnlyPipelineState::Run"}; + UpdateInputsOutputs(next_tokens, next_indices, total_length); size_t num_chunks{1}; diff --git a/src/models/env_utils.cpp b/src/models/env_utils.cpp index d4d14634cf..5260f1f0fd 100644 --- a/src/models/env_utils.cpp +++ b/src/models/env_utils.cpp @@ -11,7 +11,7 @@ namespace Generators { -std::string GetEnvironmentVariable(const char* var_name) { +std::string GetEnv(const char* var_name) { #if _MSC_VER // Why getenv() should be avoided on Windows: // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/getenv-wgetenv @@ -40,8 +40,8 @@ std::string GetEnvironmentVariable(const char* var_name) { #endif // _MSC_VER } -void GetEnvironmentVariable(const char* var_name, bool& value) { - std::string str_value = GetEnvironmentVariable(var_name); +void GetEnv(const char* var_name, bool& value) { + std::string str_value = GetEnv(var_name); if (str_value == "1" || str_value == "true") { value = true; } else if (str_value == "0" || str_value == "false") { diff --git a/src/models/env_utils.h b/src/models/env_utils.h index d436bedda1..dc204c632b 100644 --- a/src/models/env_utils.h +++ b/src/models/env_utils.h @@ -5,11 +5,12 @@ namespace Generators { -std::string GetEnvironmentVariable(const char* var_name); +// Gets the environment variable value. If no environment variable is found, the result will be empty. +std::string GetEnv(const char* var_name); // This overload is used to get boolean environment variables. // If the environment variable is set to "1" or "true" (case-sensitive), value will be set to true. // Otherwise, value will not be modified. -void GetEnvironmentVariable(const char* var_name, bool& value); +void GetEnv(const char* var_name, bool& value); } // namespace Generators diff --git a/src/models/model.cpp b/src/models/model.cpp index 56c49d3214..1551005bd8 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -11,6 +11,7 @@ #include "../generators.h" #include "../search.h" +#include "../tracing.h" #include "model.h" #include "gpt.h" #include "decoder_only.h" @@ -36,6 +37,8 @@ State::State(const GeneratorParams& params, const Model& model) } void State::Run(OrtSession& session, bool graph_capture_this_run) { + DurationTrace trace{"State::Run"}; + if (params_->use_graph_capture) { if (graph_capture_this_run) run_options_->AddConfigEntry("gpu_graph_id", graph_id_.c_str()); diff --git a/src/models/onnxruntime_api.h b/src/models/onnxruntime_api.h index 0e9593400d..b9716d5aee 100644 --- a/src/models/onnxruntime_api.h +++ b/src/models/onnxruntime_api.h @@ -200,7 +200,7 @@ inline void InitApi() { } bool ort_lib = false; - Generators::GetEnvironmentVariable("ORTGENAI_LOG_ORT_LIB", ort_lib); + Generators::GetEnv("ORTGENAI_LOG_ORT_LIB", ort_lib); if (ort_lib) { Generators::SetLogBool("enabled", true); Generators::SetLogBool("ort_lib", true); diff --git a/src/tracing.cpp b/src/tracing.cpp new file mode 100644 index 0000000000..74b871d267 --- /dev/null +++ b/src/tracing.cpp @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "tracing.h" + +#include +#include +#include +#include +#include +#include + +#include "models/env_utils.h" + +namespace Generators { + +#if defined(ORTGENAI_ENABLE_TRACING) + +namespace { + +// Writes trace events to a file in Chrome tracing format. +// See more details about the format here: +// https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU +class FileTraceSink : public TraceSink { + public: + FileTraceSink(std::string_view file_path) + : ostream_{std::ofstream{file_path.data()}}, + start_{Clock::now()}, + insert_event_delimiter_{false} { + ostream_ << "["; + } + + ~FileTraceSink() { + ostream_ << "]\n"; + } + + void BeginDuration(std::string_view label) { + LogEvent("B", label); + } + + void EndDuration() { + LogEvent("E"); + } + + private: + using Clock = std::chrono::steady_clock; + + void LogEvent(std::string_view phase_type, std::optional label = std::nullopt) { + const auto thread_id = std::this_thread::get_id(); + const auto ts = std::chrono::duration_cast(Clock::now() - start_); + + std::ostringstream event{}; + + event << "{"; + + if (label.has_value()) { + event << "\"name\": \"" << *label << "\", "; + } + + event << "\"cat\": \"perf\", " + << "\"ph\": \"" << phase_type << "\", " + << "\"pid\": 0, " + << "\"tid\": " << thread_id << ", " + << "\"ts\": " << ts.count() + << "}"; + + { + std::scoped_lock g{output_mutex_}; + + // add the delimiter only after writing the first event + if (insert_event_delimiter_) { + ostream_ << ",\n"; + } else { + insert_event_delimiter_ = true; + } + + ostream_ << event.str(); + } + } + + std::ofstream ostream_; + const Clock::time_point start_; + bool insert_event_delimiter_; + + std::mutex output_mutex_; +}; + +std::string GetTraceFileName() { + constexpr const char* kTraceFileEnvironmentVariableName = "ORTGENAI_TRACE_FILE_PATH"; + auto trace_file_name = GetEnv(kTraceFileEnvironmentVariableName); + if (trace_file_name.empty()) { + trace_file_name = "ortgenai_trace.log"; + } + return trace_file_name; +} + +} // namespace + +#endif // defined(ORTGENAI_ENABLE_TRACING) + +Tracer::Tracer() { +#if defined(ORTGENAI_ENABLE_TRACING) + const auto trace_file_name = GetTraceFileName(); + sink_ = std::make_unique(trace_file_name); +#endif +} + +void Tracer::BeginDuration(std::string_view label) { +#if defined(ORTGENAI_ENABLE_TRACING) + sink_->BeginDuration(label); +#else + static_cast(label); +#endif +} + +void Tracer::EndDuration() { +#if defined(ORTGENAI_ENABLE_TRACING) + sink_->EndDuration(); +#endif +} + +Tracer& DefaultTracerInstance() { + static auto tracer = Tracer{}; + return tracer; +} + +} // namespace Generators diff --git a/src/tracing.h b/src/tracing.h new file mode 100644 index 0000000000..94132277a1 --- /dev/null +++ b/src/tracing.h @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Build with CMake option ENABLE_TRACING=ON to enable tracing. +// To avoid performance overhead, tracing is not enabled by default. + +// When tracing is enabled, the trace data will be recorded to a file. +// The trace file path can be specified with the environment variable ORTGENAI_TRACE_FILE_PATH. +// The trace file can be viewed with Perfetto UI (https://ui.perfetto.dev/). + +#pragma once + +#include +#include + +namespace Generators { + +// Trace consumer interface. +class TraceSink { + public: + virtual void BeginDuration(std::string_view label) = 0; + virtual void EndDuration() = 0; + virtual ~TraceSink() = default; +}; + +// Main tracing class. +class Tracer { + public: + Tracer(); + + // Begins a traced duration with the given label. + void BeginDuration(std::string_view label); + + // Ends the traced duration from the most recent call to BeginDuration() in the same thread. + void EndDuration(); + + private: + Tracer(const Tracer&) = delete; + Tracer& operator=(const Tracer&) = delete; + Tracer(Tracer&&) = delete; + Tracer& operator=(Tracer&&) = delete; + +#if defined(ORTGENAI_ENABLE_TRACING) + std::unique_ptr sink_; +#endif +}; + +// Gets the default tracer instance. +Tracer& DefaultTracerInstance(); + +// Records a traced duration while in scope. +class DurationTrace { + public: + [[nodiscard]] DurationTrace(std::string_view label) + : DurationTrace{DefaultTracerInstance(), label} { + } + + [[nodiscard]] DurationTrace(Tracer& tracer, std::string_view label) + : tracer_{tracer} { + tracer_.BeginDuration(label); + } + + ~DurationTrace() { + tracer_.EndDuration(); + } + + private: + DurationTrace(const DurationTrace&) = delete; + DurationTrace& operator=(const DurationTrace&) = delete; + DurationTrace(DurationTrace&&) = delete; + DurationTrace& operator=(DurationTrace&&) = delete; + + Tracer& tracer_; +}; + +} // namespace Generators