Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ if(ENABLE_TESTS)
endif()
endif()

# When the ENABLE_TRACING option is on, define ORTGENAI_ENABLE_TRACING so that
# code guarded by that macro (see src/tracing.h / src/tracing.cpp) is compiled in.
if(ENABLE_TRACING)
message(STATUS "Tracing is enabled.")
add_compile_definitions(ORTGENAI_ENABLE_TRACING)
endif()

find_package(Threads REQUIRED)

if(WIN32)
Expand Down
3 changes: 3 additions & 0 deletions cmake/options.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ option(TEST_PHI2 "Enable tests for Phi2" OFF)

# performance
option(ENABLE_MODEL_BENCHMARK "Build model benchmark program" ON)

# diagnostics
option(ENABLE_TRACING "Enable recording of tracing data" OFF)
7 changes: 6 additions & 1 deletion src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "models/decoder_only.h"
#include "constrained_logits_processor.h"
#include "search.h"
#include "tracing.h"
#include "cpu/interface.h"
#include "cuda/interface.h"
#include "dml/interface.h"
Expand Down Expand Up @@ -47,7 +48,7 @@ static bool _ = (Ort::InitApi(), false);

// Picks the default ORT logging level from the ORTGENAI_ORT_VERBOSE_LOGGING environment variable.
// Returns ORT_LOGGING_LEVEL_VERBOSE when the flag ends up true, ORT_LOGGING_LEVEL_ERROR otherwise.
static OrtLoggingLevel GetDefaultOrtLoggingLevel() {
// Starts as false; the boolean env helper only overwrites it for values it recognizes.
bool ort_verbose_logging = false;
GetEnvironmentVariable("ORTGENAI_ORT_VERBOSE_LOGGING", ort_verbose_logging);
GetEnv("ORTGENAI_ORT_VERBOSE_LOGGING", ort_verbose_logging);
return ort_verbose_logging ? OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE : OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;
}

Expand Down Expand Up @@ -347,6 +348,8 @@ void Generator::AuxAppendTokens(cpu_span<const int32_t> input_ids) {
}

void Generator::AppendTokens(cpu_span<const int32_t> input_ids) {
DurationTrace trace{"Generator::AppendTokens"};

ThrowErrorIfSessionTerminated(state_->session_terminated_);
if (input_ids.size() == 0)
throw std::runtime_error("input_ids is empty");
Expand Down Expand Up @@ -434,6 +437,8 @@ void Generator::SetLogits(DeviceSpan<float> logits) {
}

void Generator::GenerateNextToken() {
DurationTrace trace{"Generator::GenerateNextToken"};

ThrowErrorIfSessionTerminated(state_->session_terminated_);
if (search_->GetSequenceLength() == 0 && !computed_logits_)
throw std::runtime_error("GenerateNextToken called with no prior state. Please call AppendTokens, SetLogits, or params.SetInputs before calling GenerateNextToken.");
Expand Down
5 changes: 5 additions & 0 deletions src/models/decoder_only_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "../generators.h"
#include "../logging.h"
#include "../tracing.h"
#include "decoder_only_pipeline.h"
#include "windowed_kv_cache.h"

Expand Down Expand Up @@ -165,6 +166,8 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
continue;
}

DurationTrace trace{MakeString("DecoderOnlyPipelineState::RunPipeline[", pipeline_state->id_, "]")};

if (model_.config_->model.decoder.pipeline[pipeline_state->id_].reset_session_idx > -1) {
if (model_.config_->model.decoder.pipeline[pipeline_state->id_].reset_session_idx >=
static_cast<int>(model_.sessions_.size())) {
Expand Down Expand Up @@ -293,6 +296,8 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>

DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int32_t>& next_tokens,
DeviceSpan<int32_t> next_indices) {
DurationTrace trace{"DecoderOnlyPipelineState::Run"};

UpdateInputsOutputs(next_tokens, next_indices, total_length);

size_t num_chunks{1};
Expand Down
6 changes: 3 additions & 3 deletions src/models/env_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

namespace Generators {

std::string GetEnvironmentVariable(const char* var_name) {
std::string GetEnv(const char* var_name) {
#if _MSC_VER
// Why getenv() should be avoided on Windows:
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/getenv-wgetenv
Expand Down Expand Up @@ -40,8 +40,8 @@ std::string GetEnvironmentVariable(const char* var_name) {
#endif // _MSC_VER
}

void GetEnvironmentVariable(const char* var_name, bool& value) {
std::string str_value = GetEnvironmentVariable(var_name);
void GetEnv(const char* var_name, bool& value) {
std::string str_value = GetEnv(var_name);
if (str_value == "1" || str_value == "true") {
value = true;
} else if (str_value == "0" || str_value == "false") {
Expand Down
5 changes: 3 additions & 2 deletions src/models/env_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

namespace Generators {

std::string GetEnvironmentVariable(const char* var_name);
// Gets the environment variable value. If no environment variable is found, the result will be empty.
std::string GetEnv(const char* var_name);

// This overload is used to get boolean environment variables.
// If the environment variable is set to "1" or "true" (case-sensitive), value will be set to true.
// Otherwise, value will not be modified.
void GetEnvironmentVariable(const char* var_name, bool& value);
void GetEnv(const char* var_name, bool& value);

} // namespace Generators
3 changes: 3 additions & 0 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "../generators.h"
#include "../search.h"
#include "../tracing.h"
#include "model.h"
#include "gpt.h"
#include "decoder_only.h"
Expand All @@ -36,6 +37,8 @@ State::State(const GeneratorParams& params, const Model& model)
}

void State::Run(OrtSession& session, bool graph_capture_this_run) {
DurationTrace trace{"State::Run"};

if (params_->use_graph_capture) {
if (graph_capture_this_run)
run_options_->AddConfigEntry("gpu_graph_id", graph_id_.c_str());
Expand Down
2 changes: 1 addition & 1 deletion src/models/onnxruntime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ inline void InitApi() {
}

bool ort_lib = false;
Generators::GetEnvironmentVariable("ORTGENAI_LOG_ORT_LIB", ort_lib);
Generators::GetEnv("ORTGENAI_LOG_ORT_LIB", ort_lib);
if (ort_lib) {
Generators::SetLogBool("enabled", true);
Generators::SetLogBool("ort_lib", true);
Expand Down
127 changes: 127 additions & 0 deletions src/tracing.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "tracing.h"

#include <chrono>
#include <fstream>
#include <mutex>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <string_view>
#include <thread>

#include "models/env_utils.h"

namespace Generators {

#if defined(ORTGENAI_ENABLE_TRACING)

namespace {

// Writes trace events to a file in Chrome tracing format.
// See more details about the format here:
// https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU
class FileTraceSink : public TraceSink {
public:
FileTraceSink(std::string_view file_path)
: ostream_{std::ofstream{file_path.data()}},
start_{Clock::now()},
insert_event_delimiter_{false} {
ostream_ << "[";
}

~FileTraceSink() {
ostream_ << "]\n";
}

void BeginDuration(std::string_view label) {
LogEvent("B", label);
}

void EndDuration() {
LogEvent("E");
}

private:
using Clock = std::chrono::steady_clock;

void LogEvent(std::string_view phase_type, std::optional<std::string_view> label = std::nullopt) {
const auto thread_id = std::this_thread::get_id();
const auto ts = std::chrono::duration_cast<std::chrono::microseconds>(Clock::now() - start_);

std::ostringstream event{};

event << "{";

if (label.has_value()) {
event << "\"name\": \"" << *label << "\", ";
}

event << "\"cat\": \"perf\", "
<< "\"ph\": \"" << phase_type << "\", "
<< "\"pid\": 0, "
<< "\"tid\": " << thread_id << ", "
<< "\"ts\": " << ts.count()
<< "}";

{
std::scoped_lock g{output_mutex_};

// add the delimiter only after writing the first event
if (insert_event_delimiter_) {
ostream_ << ",\n";
} else {
insert_event_delimiter_ = true;
}

ostream_ << event.str();
}
}

std::ofstream ostream_;
const Clock::time_point start_;
bool insert_event_delimiter_;

std::mutex output_mutex_;
};

// Returns the path of the trace output file: the value of the ORTGENAI_TRACE_FILE_PATH
// environment variable when it is non-empty, otherwise the default "ortgenai_trace.log".
std::string GetTraceFileName() {
  std::string path_from_env = GetEnv("ORTGENAI_TRACE_FILE_PATH");
  if (!path_from_env.empty()) {
    return path_from_env;
  }
  return "ortgenai_trace.log";
}

} // namespace

#endif // defined(ORTGENAI_ENABLE_TRACING)

// Creates the tracer. In tracing builds this installs a file-backed sink whose path comes from
// GetTraceFileName(); in other builds there is nothing to set up.
Tracer::Tracer() {
#if defined(ORTGENAI_ENABLE_TRACING)
  sink_ = std::make_unique<FileTraceSink>(GetTraceFileName());
#endif
}

// Forwards the start of a traced duration to the sink.
// `label` is unused (and the call does nothing) when tracing is compiled out.
void Tracer::BeginDuration([[maybe_unused]] std::string_view label) {
#if defined(ORTGENAI_ENABLE_TRACING)
  sink_->BeginDuration(label);
#endif
}

// Forwards the end of a traced duration to the sink; does nothing when tracing is compiled out.
void Tracer::EndDuration() {
#ifdef ORTGENAI_ENABLE_TRACING
  sink_->EndDuration();
#endif  // ORTGENAI_ENABLE_TRACING
}

// Returns the process-wide default tracer.
Tracer& DefaultTracerInstance() {
  // Function-local static: constructed the first time this function is called and reused afterwards.
  static Tracer default_tracer{};
  return default_tracer;
}

} // namespace Generators
76 changes: 76 additions & 0 deletions src/tracing.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Build with CMake option ENABLE_TRACING=ON to enable tracing.
// To avoid performance overhead, tracing is not enabled by default.

// When tracing is enabled, the trace data will be recorded to a file.
// The trace file path can be specified with the environment variable ORTGENAI_TRACE_FILE_PATH.
// The trace file can be viewed with Perfetto UI (https://ui.perfetto.dev/).

#pragma once

#include <memory>
#include <string_view>

namespace Generators {

// Trace consumer interface.
// A sink receives a begin notification (with a label) and a matching end notification for each
// traced duration and decides how to record them.
class TraceSink {
 public:
  virtual ~TraceSink() = default;

  // Called when a traced duration with the given label starts.
  virtual void BeginDuration(std::string_view label) = 0;

  // Called when a traced duration finishes.
  virtual void EndDuration() = 0;
};

// Main tracing class.
// Duration events are handed to a TraceSink when the library is built with
// ORTGENAI_ENABLE_TRACING; without it the class holds no sink.
class Tracer {
 public:
  Tracer();

  // A tracer is neither copyable nor movable.
  Tracer(const Tracer&) = delete;
  Tracer(Tracer&&) = delete;
  Tracer& operator=(const Tracer&) = delete;
  Tracer& operator=(Tracer&&) = delete;

  // Begins a traced duration with the given label.
  void BeginDuration(std::string_view label);

  // Ends the traced duration from the most recent call to BeginDuration() in the same thread.
  void EndDuration();

 private:
#if defined(ORTGENAI_ENABLE_TRACING)
  // Destination for trace events; only present in tracing builds.
  std::unique_ptr<TraceSink> sink_;
#endif
};

// Gets the default tracer instance.
// This is the tracer used by DurationTrace when no tracer is passed explicitly.
Tracer& DefaultTracerInstance();

// Records a traced duration while in scope.
// Construction calls BeginDuration() on the tracer and destruction calls EndDuration(), so the
// traced duration spans the lifetime of the object. The label is only used during construction;
// it is not stored. The referenced tracer must outlive this object.
//
// Typical use:
//   DurationTrace trace{"MyFunction"};
class DurationTrace {
 public:
  // Traces a duration named `label` on the default tracer instance.
  // Explicit so that a string is never silently converted into a scope guard.
  [[nodiscard]] explicit DurationTrace(std::string_view label)
      : DurationTrace{DefaultTracerInstance(), label} {
  }

  // Traces a duration named `label` on the given tracer.
  [[nodiscard]] DurationTrace(Tracer& tracer, std::string_view label)
      : tracer_{tracer} {
    tracer_.BeginDuration(label);
  }

  // Ends the duration started by the constructor.
  ~DurationTrace() {
    tracer_.EndDuration();
  }

  // Not copyable or movable: each object corresponds to exactly one begin/end pair.
  DurationTrace(const DurationTrace&) = delete;
  DurationTrace& operator=(const DurationTrace&) = delete;
  DurationTrace(DurationTrace&&) = delete;
  DurationTrace& operator=(DurationTrace&&) = delete;

 private:
  Tracer& tracer_;
};

} // namespace Generators
Loading