diff --git a/.github/workflows/integration-tests-amd.yml b/.github/workflows/integration-tests-amd.yml
index 718902463d01..21e51112d5c7 100644
--- a/.github/workflows/integration-tests-amd.yml
+++ b/.github/workflows/integration-tests-amd.yml
@@ -205,3 +205,119 @@ jobs:
         run: |
           rm -rf ~/.triton/cache
           rm -rf ~/.ccache
+
+  # Builds Triton in the ROCm 7.12 container on a gfx950 runner and runs `make test-proton`.
+  proton-tests-amd-rocm712:
+    if: ${{ always() && github.repository == 'triton-lang/triton' }}
+    needs: integration-tests-amd
+    name: proton-tests-amd (gfx950-rocm712)
+    runs-on: ["amd-gfx950"]
+    timeout-minutes: 25
+    env:
+      TRITON_BUILD_WITH_CCACHE: "true"
+      TRITON_BUILD_WITH_CLANG_LLD: "TRUE"
+      TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE"
+      TRITON_DISABLE_LINE_INFO: 1
+      PROTON_SKIP_PC_SAMPLING_TEST: 1
+      PYTHON: "python3"
+      CCACHE_COMPRESS: "true"
+      PIP_BREAK_SYSTEM_PACKAGES: 1
+    container:
+      image: rocm/vllm:rocm7.12.0_gfx950-dcgpu_ubuntu24.04_py3.12_pytorch_2.9.1_vllm_0.16.0
+      options: >-
+        --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
+        --env-file /etc/podinfo/gha-gpu-isolation-settings
+        --volume /home/runner/.triton:/github/home/.triton
+        --volume /triton-data:/triton-data
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v6
+        with:
+          submodules: 'true'
+      - name: Compute cache keys
+        id: cache-key
+        run: |
+          llvm_file="cmake/llvm-hash.txt"
+          nvidia_file="cmake/nvidia-toolchain-version.json"
+          json_file="cmake/json-version.txt"
+
+          if [[ ! -f "$llvm_file" || ! -f "$nvidia_file" || ! -f "$json_file" ]]; then
+            echo "Error: Required dependency files are missing."
+            exit 1
+          fi
+
+          echo "llvm=$(cat $llvm_file | cut -c 1-8)" >> $GITHUB_OUTPUT
+          echo "nvidia=$(sha256sum $nvidia_file | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT
+          echo "json=$(cat $json_file)" >> $GITHUB_OUTPUT
+        shell: bash
+      - name: Cache build dependencies
+        uses: actions/cache@v4
+        with:
+          path: |
+            ~/.triton/llvm
+            ~/.triton/nvidia
+            ~/.triton/json
+          key: ${{ runner.os }}-${{ runner.arch }}-llvm-${{ steps.cache-key.outputs.llvm }}-nvidia-${{ steps.cache-key.outputs.nvidia }}-json-${{ steps.cache-key.outputs.json }}
+      - name: Install dependencies
+        run: |
+          for i in 1 2 3; do
+            apt-get -o Acquire::Retries=5 update && break
+            echo "apt-get update attempt $i failed, retrying in 10s..."
+            sleep 10
+          done
+          apt-get install -y clang lld ccache
+          command -v clang && command -v lld && command -v ccache
+      - name: Inspect cache directories
+        run: |
+          mkdir -p ~/.triton
+          du -h -d 1 ~/.triton
+
+          mkdir -p ~/.ccache
+          du -h -d 1 ~/.ccache
+      - name: Update compiler to Clang
+        run: |
+          # Each step runs in its own shell, so a plain `export` is lost; persist through GITHUB_ENV.
+          echo "CC=/usr/bin/clang" >> "$GITHUB_ENV"
+          echo "CXX=/usr/bin/clang++" >> "$GITHUB_ENV"
+      - name: Install Triton
+        run: |
+          echo "PATH is '$PATH'"
+          pip uninstall -y triton pytorch-triton-rocm
+
+          ccache --zero-stats
+          pip install --cache-dir /triton-data/pip-cache -r python/requirements.txt
+          pip install --cache-dir /triton-data/pip-cache -r python/test-requirements.txt
+          make dev-install
+      - name: Print ccache stats
+        run: ccache --print-stats
+      - name: Run Proton tests
+        run: |
+          unset HIP_VISIBLE_DEVICES
+          unset ROCR_VISIBLE_DEVICES
+          ROCM_SDK_LIB="$(python3 -c 'import _rocm_sdk_core, os; print(os.path.join(os.path.dirname(_rocm_sdk_core.__file__), "lib"))')"
+          echo "ROCM_SDK_LIB=$ROCM_SDK_LIB"
+          # Create an unversioned <base>.so symlink when only versioned <base>.so.* files exist.
+          for base in libamdhip64 librocprofiler-sdk librocprofiler-sdk-attach; do
+            if [ ! -e "$ROCM_SDK_LIB/${base}.so" ]; then
+              versioned="$(ls "$ROCM_SDK_LIB"/${base}.so.* 2>/dev/null | sort -V | head -1 || true)"
+              if [ -n "$versioned" ]; then
+                ln -sf "$(basename "$versioned")" "$ROCM_SDK_LIB/${base}.so"
+                echo "linked $ROCM_SDK_LIB/${base}.so -> $(basename "$versioned")"
+              fi
+            fi
+          done
+          # Append the previous value only when non-empty; a trailing ':' would add the cwd to the search path.
+          export LD_LIBRARY_PATH="$ROCM_SDK_LIB${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
+          make test-proton
+      - name: Inspect cache directories
+        run: |
+          mkdir -p ~/.triton
+          du -h -d 1 ~/.triton
+
+          mkdir -p ~/.ccache
+          du -h -d 1 ~/.ccache
+      - name: Clean up caches
+        if: always()
+        run: |
+          rm -rf ~/.triton/cache
+          rm -rf ~/.ccache
diff --git a/setup.py b/setup.py
index 0acc84a39db3..5a534267b9be 100644
--- a/setup.py
+++ b/setup.py
@@ -260,10 +260,11 @@ def get_proton_cmake_args(self):
         if cupti_include_dir == "":
             cupti_include_dir = os.path.join(get_base_dir(), "third_party", "nvidia", "backend", "include")
         cmake_args += ["-DCUPTI_INCLUDE_DIR=" + cupti_include_dir]
-        roctracer_include_dir = get_env_with_keys(["TRITON_ROCTRACER_INCLUDE_PATH"])
-        if roctracer_include_dir == "":
-            roctracer_include_dir = os.path.join(get_base_dir(), "third_party", "amd", "backend", "include")
-        cmake_args += ["-DROCTRACER_INCLUDE_DIR=" + roctracer_include_dir]
+        # Legacy TRITON_ROCTRACER_INCLUDE_PATH stays as a fallback key (assumes keys are tried in order - TODO confirm).
+        rocm_include_dir = get_env_with_keys(["TRITON_ROCM_INCLUDE_PATH", "TRITON_ROCTRACER_INCLUDE_PATH"])
+        if rocm_include_dir == "":
+            rocm_include_dir = os.path.join(get_base_dir(), "third_party", "amd", "backend", "include")
+        cmake_args += ["-DROCM_INCLUDE_DIR=" + rocm_include_dir]
         return cmake_args
 
     def build_extension(self, ext):
diff --git a/third_party/proton/CMakeLists.txt b/third_party/proton/CMakeLists.txt
index abd90d689f4f..60835ff3c790 100644
--- a/third_party/proton/CMakeLists.txt
+++ b/third_party/proton/CMakeLists.txt
@@ -7,8 +7,8 @@ set(PROTON_COMMON_DIR "${CMAKE_CURRENT_SOURCE_DIR}/common")
 if(NOT CUPTI_INCLUDE_DIR)
   message(FATAL_ERROR "CUPTI include directory not defined")
 endif()
-if(NOT ROCTRACER_INCLUDE_DIR)
-  message(FATAL_ERROR "ROCTRACER include directory not
defined") +if(NOT ROCM_INCLUDE_DIR) + message(FATAL_ERROR "ROCM include directory not defined") endif() if(NOT JSON_INCLUDE_DIR) message(FATAL_ERROR "JSON include directory not defined") @@ -30,7 +30,7 @@ function(add_proton_library name) # Use system to skip warnings caused by legacy clang compilers target_include_directories(${name} SYSTEM PRIVATE - "${ROCTRACER_INCLUDE_DIR}" + "${ROCM_INCLUDE_DIR}" ) target_include_directories(${name} diff --git a/third_party/proton/csrc/include/Driver/GPU/RocprofApi.h b/third_party/proton/csrc/include/Driver/GPU/RocprofApi.h new file mode 100644 index 000000000000..07462fd0f231 --- /dev/null +++ b/third_party/proton/csrc/include/Driver/GPU/RocprofApi.h @@ -0,0 +1,88 @@ +#ifndef PROTON_DRIVER_GPU_ROCPROFILER_API_H_ +#define PROTON_DRIVER_GPU_ROCPROFILER_API_H_ + +#include "Driver/Dispatch.h" +#include "rocprofiler-sdk/agent.h" +#include "rocprofiler-sdk/buffer.h" +#include "rocprofiler-sdk/buffer_tracing.h" +#include "rocprofiler-sdk/callback_tracing.h" +#include "rocprofiler-sdk/fwd.h" +#include "rocprofiler-sdk/hip/api_args.h" +#include "rocprofiler-sdk/hip/runtime_api_id.h" +#include "rocprofiler-sdk/internal_threading.h" +#include "rocprofiler-sdk/registration.h" + +namespace proton { + +namespace rocprofiler { + +struct ExternLibRocprofiler : public ExternLibBase { + using RetType = rocprofiler_status_t; + static constexpr const char *name = "librocprofiler-sdk.so"; + static constexpr const char *symbolName = "rocprofiler_is_initialized"; + static constexpr const char *pathEnv = "TRITON_ROCPROFILER_SDK_LIB_PATH"; + static constexpr RetType success = ROCPROFILER_STATUS_SUCCESS; + static inline void *lib = nullptr; +}; + +template rocprofiler_status_t isInitialized(int *status); + +template +rocprofiler_status_t forceConfigure(rocprofiler_configure_func_t configureFunc); + +template +rocprofiler_status_t createContext(rocprofiler_context_id_t *context); + +template +rocprofiler_status_t destroyContext(rocprofiler_context_id_t 
context); + +template +rocprofiler_status_t startContext(rocprofiler_context_id_t context); + +template +rocprofiler_status_t stopContext(rocprofiler_context_id_t context); + +template +rocprofiler_status_t +createBuffer(rocprofiler_context_id_t context, size_t size, size_t watermark, + rocprofiler_buffer_policy_t policy, + rocprofiler_buffer_tracing_cb_t callback, void *userData, + rocprofiler_buffer_id_t *buffer); + +template +rocprofiler_status_t destroyBuffer(rocprofiler_buffer_id_t buffer); + +template +rocprofiler_status_t flushBuffer(rocprofiler_buffer_id_t buffer); + +template +rocprofiler_status_t configureBufferTracingService( + rocprofiler_context_id_t context, rocprofiler_buffer_tracing_kind_t kind, + const rocprofiler_tracing_operation_t *operations, size_t operationCount, + rocprofiler_buffer_id_t buffer); + +template +rocprofiler_status_t configureCallbackTracingService( + rocprofiler_context_id_t context, rocprofiler_callback_tracing_kind_t kind, + const rocprofiler_tracing_operation_t *operations, size_t operationCount, + rocprofiler_callback_tracing_cb_t callback, void *userData); + +template +rocprofiler_status_t +createCallbackThread(rocprofiler_callback_thread_t *thread); + +template +rocprofiler_status_t assignCallbackThread(rocprofiler_buffer_id_t buffer, + rocprofiler_callback_thread_t thread); + +template +rocprofiler_status_t +queryAvailableAgents(rocprofiler_agent_version_t version, + rocprofiler_query_available_agents_cb_t callback, + size_t agentSize, void *userData); + +} // namespace rocprofiler + +} // namespace proton + +#endif // PROTON_DRIVER_GPU_ROCPROFILER_API_H_ diff --git a/third_party/proton/csrc/include/Profiler/GPUProfiler.h b/third_party/proton/csrc/include/Profiler/GPUProfiler.h index 77546556b9d5..c7eb1c19748d 100644 --- a/third_party/proton/csrc/include/Profiler/GPUProfiler.h +++ b/third_party/proton/csrc/include/Profiler/GPUProfiler.h @@ -27,7 +27,6 @@ namespace detail { void flushDataPhasesImpl( const bool 
periodicFlushEnabled, const std::string &periodicFlushingFormat, - std::map &dataFlushedPhases, const std::map> &dataPhases, @@ -93,14 +92,12 @@ class GPUProfiler : public Profiler, } void flushDataPhases( - std::map &dataFlushedPhases, const std::map> &dataPhases, PendingGraphPool *pendingGraphPool) { detail::flushDataPhasesImpl(periodicFlushingEnabled, periodicFlushingFormat, - dataFlushedPhases, dataPhases, - pendingGraphPool); + dataPhases, pendingGraphPool); } // Profiler diff --git a/third_party/proton/csrc/include/Profiler/RocprofSDK/RocprofSDKProfiler.h b/third_party/proton/csrc/include/Profiler/RocprofSDK/RocprofSDKProfiler.h new file mode 100644 index 000000000000..25e252380761 --- /dev/null +++ b/third_party/proton/csrc/include/Profiler/RocprofSDK/RocprofSDKProfiler.h @@ -0,0 +1,22 @@ +#ifndef PROTON_PROFILER_ROCPROFSDK_PROFILER_H_ +#define PROTON_PROFILER_ROCPROFSDK_PROFILER_H_ + +#include "Profiler/GPUProfiler.h" + +namespace proton { + +class RocprofSDKProfiler : public GPUProfiler { +public: + RocprofSDKProfiler(); + virtual ~RocprofSDKProfiler(); + + struct RocprofSDKProfilerPimpl; + +private: + virtual void + doSetMode(const std::vector &modeAndOptions) override; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_ROCPROFSDK_PROFILER_H_ diff --git a/third_party/proton/csrc/lib/Driver/CMakeLists.txt b/third_party/proton/csrc/lib/Driver/CMakeLists.txt index 438f24f49e1c..9d85eafe5187 100644 --- a/third_party/proton/csrc/lib/Driver/CMakeLists.txt +++ b/third_party/proton/csrc/lib/Driver/CMakeLists.txt @@ -4,6 +4,7 @@ add_proton_library(ProtonDriver GPU/CuptiApi.cpp GPU/HipApi.cpp GPU/HsaApi.cpp - GPU/RoctracerApi.cpp GPU/NvtxApi.cpp + GPU/RoctracerApi.cpp + GPU/RocprofApi.cpp ) diff --git a/third_party/proton/csrc/lib/Driver/GPU/RocprofApi.cpp b/third_party/proton/csrc/lib/Driver/GPU/RocprofApi.cpp new file mode 100644 index 000000000000..2f359f233883 --- /dev/null +++ b/third_party/proton/csrc/lib/Driver/GPU/RocprofApi.cpp @@ -0,0 +1,60 @@ 
+#include "Driver/GPU/RocprofApi.h" + +namespace proton { +namespace rocprofiler { + +DEFINE_DISPATCH(ExternLibRocprofiler, isInitialized, rocprofiler_is_initialized, + int *) + +DEFINE_DISPATCH(ExternLibRocprofiler, forceConfigure, + rocprofiler_force_configure, rocprofiler_configure_func_t) + +DEFINE_DISPATCH(ExternLibRocprofiler, createContext, rocprofiler_create_context, + rocprofiler_context_id_t *) + +DEFINE_DISPATCH(ExternLibRocprofiler, destroyContext, + rocprofiler_destroy_context, rocprofiler_context_id_t) + +DEFINE_DISPATCH(ExternLibRocprofiler, startContext, rocprofiler_start_context, + rocprofiler_context_id_t) + +DEFINE_DISPATCH(ExternLibRocprofiler, stopContext, rocprofiler_stop_context, + rocprofiler_context_id_t) + +DEFINE_DISPATCH(ExternLibRocprofiler, createBuffer, rocprofiler_create_buffer, + rocprofiler_context_id_t, size_t, size_t, + rocprofiler_buffer_policy_t, rocprofiler_buffer_tracing_cb_t, + void *, rocprofiler_buffer_id_t *) + +DEFINE_DISPATCH(ExternLibRocprofiler, destroyBuffer, rocprofiler_destroy_buffer, + rocprofiler_buffer_id_t) + +DEFINE_DISPATCH(ExternLibRocprofiler, flushBuffer, rocprofiler_flush_buffer, + rocprofiler_buffer_id_t) + +DEFINE_DISPATCH(ExternLibRocprofiler, configureBufferTracingService, + rocprofiler_configure_buffer_tracing_service, + rocprofiler_context_id_t, rocprofiler_buffer_tracing_kind_t, + const rocprofiler_tracing_operation_t *, size_t, + rocprofiler_buffer_id_t) + +DEFINE_DISPATCH(ExternLibRocprofiler, configureCallbackTracingService, + rocprofiler_configure_callback_tracing_service, + rocprofiler_context_id_t, rocprofiler_callback_tracing_kind_t, + const rocprofiler_tracing_operation_t *, size_t, + rocprofiler_callback_tracing_cb_t, void *) + +DEFINE_DISPATCH(ExternLibRocprofiler, createCallbackThread, + rocprofiler_create_callback_thread, + rocprofiler_callback_thread_t *) + +DEFINE_DISPATCH(ExternLibRocprofiler, assignCallbackThread, + rocprofiler_assign_callback_thread, rocprofiler_buffer_id_t, + 
rocprofiler_callback_thread_t) + +DEFINE_DISPATCH(ExternLibRocprofiler, queryAvailableAgents, + rocprofiler_query_available_agents, rocprofiler_agent_version_t, + rocprofiler_query_available_agents_cb_t, size_t, void *) + +} // namespace rocprofiler +} // namespace proton diff --git a/third_party/proton/csrc/lib/Profiler/CMakeLists.txt b/third_party/proton/csrc/lib/Profiler/CMakeLists.txt index 00dcaef97204..0212edd36670 100644 --- a/third_party/proton/csrc/lib/Profiler/CMakeLists.txt +++ b/third_party/proton/csrc/lib/Profiler/CMakeLists.txt @@ -5,6 +5,7 @@ add_proton_library(ProtonProfiler Cupti/CuptiPCSampling.cpp Cupti/CuptiProfiler.cpp RocTracer/RoctracerProfiler.cpp + RocprofSDK/RocprofSDKProfiler.cpp Instrumentation/InstrumentationProfiler.cpp Instrumentation/Metadata.cpp ) diff --git a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp index 39cc8e990f86..046d0a1b7927 100644 --- a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp @@ -432,7 +432,6 @@ void CuptiProfiler::CuptiProfilerPimpl::completeBuffer(CUcontext ctx, size_t validSize) { CuptiProfiler &profiler = threadState.profiler; uint32_t maxCorrelationId = 0; - static thread_local std::map dataFlushedPhases; std::map> dataPhases; CUptiResult status; CUpti_Activity *activity = nullptr; @@ -456,8 +455,7 @@ void CuptiProfiler::CuptiProfilerPimpl::completeBuffer(CUcontext ctx, std::free(buffer); profiler.correlation.complete(maxCorrelationId); - profiler.flushDataPhases(dataFlushedPhases, dataPhases, - profiler.pendingGraphPool.get()); + profiler.flushDataPhases(dataPhases, profiler.pendingGraphPool.get()); } void CuptiProfiler::CuptiProfilerPimpl::handleGraphResourceCallbacks( diff --git a/third_party/proton/csrc/lib/Profiler/GPUProfiler.cpp b/third_party/proton/csrc/lib/Profiler/GPUProfiler.cpp index abb421e27be7..fa489ec4f4fe 100644 --- 
a/third_party/proton/csrc/lib/Profiler/GPUProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/GPUProfiler.cpp @@ -21,7 +21,6 @@ struct FlushRange { std::pair, std::set> computeFlushRangesAndPeekPhases( - std::map &dataFlushedPhases, const std::map> &dataPhases, @@ -35,17 +34,15 @@ computeFlushRangesAndPeekPhases( continue; } - auto flushedPhaseIt = dataFlushedPhases.find(data); // phase.second at maximum is the current phase, which cannot be a // "complete" phase yet. So we flush up to phase.second - 1. const size_t endPhaseToFlush = phase.second - 1; size_t minPhaseToFlush = 0; - if (flushedPhaseIt == dataFlushedPhases.end() || - flushedPhaseIt->second == Data::kNoCompletePhase) { + const auto flushedPhase = data->getPhaseInfo().completeUpTo; + if (flushedPhase == Data::kNoCompletePhase) { minPhaseToFlush = 0; } else { - const auto flushedPhase = flushedPhaseIt->second; if (endPhaseToFlush <= flushedPhase) { continue; } @@ -213,15 +210,14 @@ void updateDataPhases(std::map> &dataPhases, void flushDataPhasesImpl( const bool periodicFlushEnabled, const std::string &periodicFlushingFormat, - std::map &dataFlushedPhases, const std::map> &dataPhases, PendingGraphPool *pendingGraphPool) { static const bool timingEnabled = getBoolEnv("PROTON_DATA_FLUSH_TIMING", false); - auto [flushRanges, phasesToPeek] = computeFlushRangesAndPeekPhases( - dataFlushedPhases, dataPhases, pendingGraphPool != nullptr); + auto [flushRanges, phasesToPeek] = + computeFlushRangesAndPeekPhases(dataPhases, pendingGraphPool != nullptr); if (pendingGraphPool) { using Clock = std::chrono::steady_clock; uint64_t totalPeekUs = 0; @@ -252,7 +248,6 @@ void flushDataPhasesImpl( auto *data = range.data; const size_t minPhaseToFlush = range.minPhaseToFlush; const size_t maxPhaseToFlush = range.maxPhaseToFlush; - dataFlushedPhases[data] = maxPhaseToFlush; data->completePhase(maxPhaseToFlush); if (!periodicFlushEnabled) diff --git a/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp 
b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp index 6fe9aaf44569..68614fc3f8c0 100644 --- a/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp @@ -418,7 +418,6 @@ void RoctracerProfiler::RoctracerProfilerPimpl::activityCallback( profiler.pImpl.get()); auto &correlation = profiler.correlation; - static thread_local std::map dataFlushedPhases; const roctracer_record_t *record = reinterpret_cast(begin); const roctracer_record_t *endRecord = @@ -447,8 +446,7 @@ void RoctracerProfiler::RoctracerProfilerPimpl::activityCallback( roctracer::getNextRecord(record, &record); } correlation.complete(maxCorrelationId); - profiler.flushDataPhases(dataFlushedPhases, dataPhases, - profiler.pendingGraphPool.get()); + profiler.flushDataPhases(dataPhases, profiler.pendingGraphPool.get()); } void RoctracerProfiler::RoctracerProfilerPimpl::doStart() { diff --git a/third_party/proton/csrc/lib/Profiler/RocprofSDK/RocprofSDKProfiler.cpp b/third_party/proton/csrc/lib/Profiler/RocprofSDK/RocprofSDKProfiler.cpp new file mode 100644 index 000000000000..5469ce8b12ce --- /dev/null +++ b/third_party/proton/csrc/lib/Profiler/RocprofSDK/RocprofSDKProfiler.cpp @@ -0,0 +1,1006 @@ +#include "Profiler/RocprofSDK/RocprofSDKProfiler.h" + +#include "Context/Context.h" +#include "Data/Metric.h" +#include "Driver/Dispatch.h" +#include "Driver/GPU/HipApi.h" +#include "Driver/GPU/RocprofApi.h" +#include "Driver/GPU/RoctxTypes.h" +#include "Profiler/GPUProfiler.h" +#include "Runtime/HipRuntime.h" +#include "Utility/Env.h" +#include "Utility/Map.h" +#include "Utility/Singleton.h" + +#include "hip/hip_runtime_api.h" +#include "rocprofiler-sdk/agent.h" +#include "rocprofiler-sdk/buffer_tracing.h" +#include "rocprofiler-sdk/callback_tracing.h" +#include "rocprofiler-sdk/hip/api_args.h" +#include "rocprofiler-sdk/hip/runtime_api_id.h" +#include "rocprofiler-sdk/marker/api_id.h" +#include 
"rocprofiler-sdk/registration.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace proton { + +template <> +thread_local GPUProfiler::ThreadState + GPUProfiler::threadState( + RocprofSDKProfiler::instance()); + +namespace { + +constexpr size_t BufferSize = 64 * 1024 * 1024; +constexpr const char *UnknownKernelName = ""; + +struct RocprofSDKProfilerPimpl; + +// ---- SDK runtime state (singleton, outlives any profiler instance) ---- + +struct RocprofilerRuntimeState { + std::mutex mutex; + rocprofiler_context_id_t codeObjectContext{}; + rocprofiler_context_id_t profilingContext{}; + rocprofiler_buffer_id_t kernelBuffer{}; + rocprofiler_callback_thread_t callbackThread{}; + rocprofiler_client_finalize_t finalizeFunc = nullptr; + rocprofiler_client_id_t *clientId{nullptr}; + bool configured{false}; + bool codeObjectStarted{false}; + bool profilingStarted{false}; + std::atomic nvtxEnabled{false}; + RocprofSDKProfiler::RocprofSDKProfilerPimpl *pimpl{nullptr}; +}; + +RocprofilerRuntimeState &getRuntimeState() { + static RocprofilerRuntimeState state; + return state; +} + +using RoctxTracerCallbackFn = int (*)(uint32_t domain, uint32_t operationId, + void *data); +using RoctxRegisterTracerCallbackFn = void (*)(RoctxTracerCallbackFn); + +// registerRoctxCallback is defined after the Pimpl class (needs access to +// the static roctxCallback member). 
+void registerRoctxCallback(bool enable); + +// ---- Agent (GPU) ID mapping ---- + +class AgentIdMapper : public Singleton { +public: + AgentIdMapper() = default; + + void initialize() { + std::call_once(initializeFlag, [this]() { + rocprofiler::queryAvailableAgents( + ROCPROFILER_AGENT_INFO_VERSION_0, &AgentIdMapper::callback, + sizeof(rocprofiler_agent_t), this); + }); + } + + uint32_t map(uint64_t agentHandle) const { + auto it = agentToDevice.find(agentHandle); + if (it != agentToDevice.end()) + return it->second; + return 0; + } + +private: + static rocprofiler_status_t callback(rocprofiler_agent_version_t version, + const void **agents, size_t count, + void *userData) { + auto *self = static_cast(userData); + if (version != ROCPROFILER_AGENT_INFO_VERSION_0) + return ROCPROFILER_STATUS_ERROR_INVALID_ARGUMENT; + auto agentList = + reinterpret_cast(agents); + self->agentToDevice.clear(); + for (size_t i = 0; i < count; ++i) { + const auto *agent = agentList[i]; + if (agent->type == ROCPROFILER_AGENT_TYPE_GPU && + agent->runtime_visibility.hip) { + auto ordinal = getHipOrdinal(*agent); + if (ordinal) + self->agentToDevice[agent->id.handle] = + static_cast(*ordinal); + } + } + return ROCPROFILER_STATUS_SUCCESS; + } + + static uint64_t getUuidValue(const rocprofiler_uuid_t &uuid) { + uint64_t value = 0; + static_assert(sizeof(value) <= sizeof(uuid.bytes)); + std::memcpy(&value, uuid.bytes, sizeof(value)); + return value; + } + + static bool isDecimalOrdinal(const std::string &token) { + return !token.empty() && + std::all_of(token.begin(), token.end(), + [](unsigned char c) { return std::isdigit(c); }); + } + + // Bridge rocprofiler-sdk's ROCR agent identity to the HIP device ordinals + // used by Proton metrics. Visibility filters are layered: for example, + // ROCR_VISIBLE_DEVICES can first create a reordered physical-agent list, and + // HIP_VISIBLE_DEVICES then indexes into that ROCR-visible list. 
HIP reports + // the selected device as ordinal 0, while rocprofiler records still identify + // the underlying agent. + static std::optional + getVisibleIndex(const std::string &envName, int32_t ordinal, + const rocprofiler_uuid_t &uuid) { + auto env = getStrEnv(envName); + if (env.empty()) + return std::nullopt; + + constexpr const char *UuidPrefix = "GPU-"; + auto uuidValue = getUuidValue(uuid); + int32_t index = 0; + size_t tokenBegin = env.find_first_not_of(", "); + while (tokenBegin != std::string::npos) { + auto tokenEnd = env.find_first_of(", ", tokenBegin); + auto token = env.substr(tokenBegin, tokenEnd - tokenBegin); + if (isDecimalOrdinal(token)) { + if (std::stoll(token) == ordinal) + return index; + } else if (token.rfind(UuidPrefix, 0) == 0 && + token.size() > std::strlen(UuidPrefix)) { + auto tokenUuid = + std::strtoull(token.c_str() + std::strlen(UuidPrefix), nullptr, 16); + if (tokenUuid == uuidValue) + return index; + } + ++index; + if (tokenEnd == std::string::npos) + break; + tokenBegin = env.find_first_not_of(", ", tokenEnd); + } + return -1; + } + + static std::optional + getHipOrdinal(const rocprofiler_agent_t &agent) { + auto rocrIndex = agent.logical_node_type_id; + auto rocrVisible = getVisibleIndex("ROCR_VISIBLE_DEVICES", + agent.logical_node_type_id, agent.uuid); + if (rocrVisible) { + if (*rocrVisible < 0) + return std::nullopt; + rocrIndex = *rocrVisible; + } + + auto hipVisible = + getVisibleIndex("HIP_VISIBLE_DEVICES", rocrIndex, agent.uuid); + if (!hipVisible) + hipVisible = + getVisibleIndex("CUDA_VISIBLE_DEVICES", rocrIndex, agent.uuid); + if (!hipVisible) + hipVisible = getVisibleIndex("GPU_DEVICE_ORDINAL", rocrIndex, agent.uuid); + if (hipVisible) { + if (*hipVisible < 0) + return std::nullopt; + return *hipVisible; + } + + return rocrVisible ? 
std::optional{rocrIndex} + : std::optional{agent.logical_node_type_id}; + } + + std::once_flag initializeFlag; + std::unordered_map agentToDevice; +}; + +// ---- Metric conversion ---- + +std::unique_ptr convertDispatchToMetric( + const rocprofiler_buffer_tracing_kernel_dispatch_record_t *record, + uint64_t streamId) { + if (record->start_timestamp >= record->end_timestamp) + return nullptr; + auto deviceId = static_cast( + AgentIdMapper::instance().map(record->dispatch_info.agent_id.handle)); + return std::make_unique( + static_cast(record->start_timestamp), + static_cast(record->end_timestamp), 1, deviceId, + static_cast(DeviceType::HIP), streamId); +} + +// ---- Kernel name resolution at API ENTER time ---- + +const char *resolveKernelNameAtEnter( + rocprofiler_tracing_operation_t op, + const rocprofiler_callback_tracing_hip_api_data_t *payload) { + switch (op) { + case ROCPROFILER_HIP_RUNTIME_API_ID_hipLaunchKernel: + return hip::getKernelNameRefByPtr( + payload->args.hipLaunchKernel.function_address, + payload->args.hipLaunchKernel.stream); + case ROCPROFILER_HIP_RUNTIME_API_ID_hipExtLaunchKernel: + return hip::getKernelNameRefByPtr( + payload->args.hipExtLaunchKernel.function_address, + payload->args.hipExtLaunchKernel.stream); + case ROCPROFILER_HIP_RUNTIME_API_ID_hipLaunchCooperativeKernel: + return hip::getKernelNameRefByPtr( + payload->args.hipLaunchCooperativeKernel.func, + payload->args.hipLaunchCooperativeKernel.stream); + case ROCPROFILER_HIP_RUNTIME_API_ID_hipModuleLaunchKernel: + return hip::getKernelNameRef(payload->args.hipModuleLaunchKernel.func); + case ROCPROFILER_HIP_RUNTIME_API_ID_hipExtModuleLaunchKernel: + return hip::getKernelNameRef(payload->args.hipExtModuleLaunchKernel.func); + case ROCPROFILER_HIP_RUNTIME_API_ID_hipHccModuleLaunchKernel: + return hip::getKernelNameRef(payload->args.hipHccModuleLaunchKernel.func); + case ROCPROFILER_HIP_RUNTIME_API_ID_hipModuleLaunchCooperativeKernel: + return hip::getKernelNameRef( + 
payload->args.hipModuleLaunchCooperativeKernel.func); + case ROCPROFILER_HIP_RUNTIME_API_ID_hipExtLaunchMultiKernelMultiDevice: { + const auto *params = + payload->args.hipExtLaunchMultiKernelMultiDevice.launchParamsList; + if (params && + payload->args.hipExtLaunchMultiKernelMultiDevice.numDevices > 0) + return hip::getKernelNameRefByPtr(params->func, params->stream); + return nullptr; + } + case ROCPROFILER_HIP_RUNTIME_API_ID_hipLaunchCooperativeKernelMultiDevice: { + const auto *params = + payload->args.hipLaunchCooperativeKernelMultiDevice.launchParamsList; + if (params && + payload->args.hipLaunchCooperativeKernelMultiDevice.numDevices > 0) + return hip::getKernelNameRefByPtr(params->func, params->stream); + return nullptr; + } + case ROCPROFILER_HIP_RUNTIME_API_ID_hipModuleLaunchCooperativeKernelMultiDevice: { + const auto *params = + payload->args.hipModuleLaunchCooperativeKernelMultiDevice + .launchParamsList; + if (params && + payload->args.hipModuleLaunchCooperativeKernelMultiDevice.numDevices > + 0) + return hip::getKernelNameRef(params->function); + return nullptr; + } + default: + return nullptr; + } +} + +// ---- HIP stream extraction at API ENTER time ---- + +uint64_t +extractStreamId(rocprofiler_tracing_operation_t op, + const rocprofiler_callback_tracing_hip_api_data_t *payload) { + hipStream_t stream = nullptr; + switch (op) { + case ROCPROFILER_HIP_RUNTIME_API_ID_hipLaunchKernel: + stream = payload->args.hipLaunchKernel.stream; + break; + case ROCPROFILER_HIP_RUNTIME_API_ID_hipExtLaunchKernel: + stream = payload->args.hipExtLaunchKernel.stream; + break; + case ROCPROFILER_HIP_RUNTIME_API_ID_hipLaunchCooperativeKernel: + stream = payload->args.hipLaunchCooperativeKernel.stream; + break; + case ROCPROFILER_HIP_RUNTIME_API_ID_hipModuleLaunchKernel: + stream = payload->args.hipModuleLaunchKernel.stream; + break; + case ROCPROFILER_HIP_RUNTIME_API_ID_hipExtModuleLaunchKernel: + stream = payload->args.hipExtModuleLaunchKernel.stream; + break; + case 
ROCPROFILER_HIP_RUNTIME_API_ID_hipHccModuleLaunchKernel: + stream = payload->args.hipHccModuleLaunchKernel.stream; + break; + case ROCPROFILER_HIP_RUNTIME_API_ID_hipModuleLaunchCooperativeKernel: + stream = payload->args.hipModuleLaunchCooperativeKernel.stream; + break; + case ROCPROFILER_HIP_RUNTIME_API_ID_hipGraphLaunch: + stream = payload->args.hipGraphLaunch.stream; + break; + case ROCPROFILER_HIP_RUNTIME_API_ID_hipExtLaunchMultiKernelMultiDevice: { + const auto *p = + payload->args.hipExtLaunchMultiKernelMultiDevice.launchParamsList; + if (p && payload->args.hipExtLaunchMultiKernelMultiDevice.numDevices > 0) + stream = p->stream; + break; + } + case ROCPROFILER_HIP_RUNTIME_API_ID_hipLaunchCooperativeKernelMultiDevice: { + const auto *p = + payload->args.hipLaunchCooperativeKernelMultiDevice.launchParamsList; + if (p && payload->args.hipLaunchCooperativeKernelMultiDevice.numDevices > 0) + stream = p->stream; + break; + } + case ROCPROFILER_HIP_RUNTIME_API_ID_hipModuleLaunchCooperativeKernelMultiDevice: { + const auto *p = payload->args.hipModuleLaunchCooperativeKernelMultiDevice + .launchParamsList; + if (p && + payload->args.hipModuleLaunchCooperativeKernelMultiDevice.numDevices > + 0) + stream = p->hStream; + break; + } + default: + break; + } + return reinterpret_cast(stream); +} + +// ---- Operation classification ---- + +bool isKernelLaunchOperation(rocprofiler_tracing_operation_t op) { + switch (op) { + case ROCPROFILER_HIP_RUNTIME_API_ID_hipExtLaunchKernel: + case ROCPROFILER_HIP_RUNTIME_API_ID_hipExtLaunchMultiKernelMultiDevice: + case ROCPROFILER_HIP_RUNTIME_API_ID_hipExtModuleLaunchKernel: + case ROCPROFILER_HIP_RUNTIME_API_ID_hipHccModuleLaunchKernel: + case ROCPROFILER_HIP_RUNTIME_API_ID_hipLaunchCooperativeKernel: + case ROCPROFILER_HIP_RUNTIME_API_ID_hipLaunchCooperativeKernelMultiDevice: + case ROCPROFILER_HIP_RUNTIME_API_ID_hipLaunchKernel: + case ROCPROFILER_HIP_RUNTIME_API_ID_hipModuleLaunchKernel: + case 
ROCPROFILER_HIP_RUNTIME_API_ID_hipGraphLaunch: + case ROCPROFILER_HIP_RUNTIME_API_ID_hipModuleLaunchCooperativeKernel: + case ROCPROFILER_HIP_RUNTIME_API_ID_hipModuleLaunchCooperativeKernelMultiDevice: + return true; + default: + return false; + } +} + +// ---- Kernel dispatch processing (matches main's GPUProfiler interface) ---- + +void processKernelRecord( + RocprofSDKProfiler &profiler, + RocprofSDKProfiler::CorrIdToExternIdMap &corrIdToExternId, + RocprofSDKProfiler::ExternIdToStateMap &externIdToState, + ThreadSafeMap> + &corrIdToIsHipGraph, + std::map> &dataPhases, + const std::string &kernelName, + const rocprofiler_buffer_tracing_kernel_dispatch_record_t *record, + uint64_t streamId) { + auto externId = Scope::DummyScopeId; + bool hasCorrelation = + corrIdToExternId.withRead(record->correlation_id.internal, + [&](const size_t &value) { externId = value; }); + + if (!hasCorrelation) + return; + + if (externId == Scope::DummyScopeId) + return; + + bool isGraph = corrIdToIsHipGraph.contain(record->correlation_id.internal); + auto &state = externIdToState[externId]; + + if (!isGraph) { + for (auto [data, entry] : state.dataToEntry) { + if (auto metric = convertDispatchToMetric(record, streamId)) { + if (state.isMissingName) { + auto childEntry = + data->addOp(entry.phase, entry.id, {Context(kernelName)}); + childEntry.upsertMetric(std::move(metric)); + } else { + entry.upsertMetric(std::move(metric)); + } + detail::updateDataPhases(dataPhases, data, entry.phase); + } + } + } else { + for (auto [data, entry] : state.dataToEntry) { + if (auto metric = convertDispatchToMetric(record, streamId)) { + auto childEntry = + data->addOp(entry.phase, entry.id, {Context(kernelName)}); + childEntry.upsertMetric(std::move(metric)); + detail::updateDataPhases(dataPhases, data, entry.phase); + } + } + } + + --state.numNodes; + if (state.numNodes == 0) { + corrIdToExternId.erase(record->correlation_id.internal); + corrIdToIsHipGraph.erase(record->correlation_id.internal); + 
externIdToState.erase(externId); + } +} + +} // namespace + +// ---- Pimpl ---- + +struct RocprofSDKProfiler::RocprofSDKProfilerPimpl + : public GPUProfiler::GPUProfilerPimplInterface { + RocprofSDKProfilerPimpl(RocprofSDKProfiler &profiler) + : GPUProfiler::GPUProfilerPimplInterface(profiler) { + auto runtime = &HipRuntime::instance(); + profiler.metricBuffer = + std::make_unique(1024 * 1024 * 64, runtime); + } + virtual ~RocprofSDKProfilerPimpl() = default; + + void doStart() override; + void doFlush() override; + void doStop() override; + + static void hipRuntimeCallback(rocprofiler_callback_tracing_record_t record, + rocprofiler_user_data_t *userData, void *arg); + static void markerCallback(rocprofiler_callback_tracing_record_t record, + rocprofiler_user_data_t *userData, void *arg); + static void roctxCallback(uint32_t operationId, void *data); + static void codeObjectCallback(rocprofiler_callback_tracing_record_t record, + rocprofiler_user_data_t *userData, void *arg); + static void kernelBufferCallback(rocprofiler_context_id_t context, + rocprofiler_buffer_id_t buffer, + rocprofiler_record_header_t **headers, + size_t numHeaders, void *userData, + uint64_t dropCount); + + using KernelNameMap = + ThreadSafeMap>; + + std::string getKernelName(uint64_t kernelId) { + std::string name; + if (!kernelNames.withRead(kernelId, + [&](const std::string &v) { name = v; })) + return UnknownKernelName; + // AMDGPU ELF objects append ".kd" (kernel descriptor) to symbol names. + // Strip it so user-visible kernel names match the source. 
+ const std::string suffix = ".kd"; + if (name.size() > suffix.size() && + name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0) + name.resize(name.size() - suffix.size()); + return name; + } + + void setKernelName(uint64_t kernelId, const char *name) { + if (name == nullptr) + return; + kernelNames[kernelId] = std::string(name); + } + + ThreadSafeMap> + corrIdToIsHipGraph; + + ThreadSafeMap> + graphExecToGraph; + + ThreadSafeMap> + graphToNumInstances; + + ThreadSafeMap> + streamToCaptureCount; + + ThreadSafeMap> + streamToCapture; + + // Fast check: non-zero when any stream is being captured. Avoids acquiring + // a shared_mutex on every kernel launch EXIT just to find an empty map. + std::atomic activeCaptureCount{0}; + + KernelNameMap kernelNames; + + // correlation_id → HIP stream pointer, captured at hipLaunchKernel ENTER. + // Used to distinguish streams in trace output when the SDK's queue_id + // maps multiple HIP streams to the same underlying HSA queue. + ThreadSafeMap> + corrIdToStreamId; +}; + +// ---- HIP Runtime API callback (correlation tracking) ---- + +void RocprofSDKProfiler::RocprofSDKProfilerPimpl::hipRuntimeCallback( + rocprofiler_callback_tracing_record_t record, + rocprofiler_user_data_t *userData, void *arg) { + if (record.kind != ROCPROFILER_CALLBACK_TRACING_HIP_RUNTIME_API) + return; + + auto operation = + static_cast(record.operation); + bool isKernelOp = isKernelLaunchOperation(operation); + auto &profiler = threadState.profiler; + auto *impl = static_cast(profiler.pImpl.get()); + auto *payload = static_cast( + record.payload); + + if (record.phase == ROCPROFILER_CALLBACK_PHASE_ENTER) { + if (!isKernelOp) + return; + + const char *resolvedName = resolveKernelNameAtEnter(operation, payload); + threadState.enterOp( + Scope(resolvedName ? 
std::string(resolvedName) : std::string())); + auto &dataToEntry = threadState.dataToEntry; + size_t numInstances = 1; + if (operation == ROCPROFILER_HIP_RUNTIME_API_ID_hipGraphLaunch) { + impl->corrIdToIsHipGraph[record.correlation_id.internal] = true; + numInstances = std::numeric_limits::max(); + bool foundGraph = false; + auto graphExec = payload->args.hipGraphLaunch.graphExec; + if (impl->graphExecToGraph.contain(graphExec)) { + auto graph = impl->graphExecToGraph[graphExec]; + if (impl->graphToNumInstances.contain(graph)) { + numInstances = impl->graphToNumInstances[graph]; + foundGraph = true; + } + } + if (!foundGraph) { + std::cerr + << "[PROTON] Cannot find graph and it may cause a memory leak." + "To avoid this problem, please start profiling before the " + "graph is created." + << std::endl; + } + } + auto &scope = threadState.scopeStack.back(); + auto isMissingName = scope.name.empty(); + profiler.correlation.correlate(record.correlation_id.internal, + scope.scopeId, numInstances, isMissingName, + dataToEntry); + impl->corrIdToStreamId[record.correlation_id.internal] = + extractStreamId(operation, payload); + return; + } + + if (record.phase != ROCPROFILER_CALLBACK_PHASE_EXIT) + return; + + switch (operation) { + case ROCPROFILER_HIP_RUNTIME_API_ID_hipStreamBeginCapture: { + auto stream = payload->args.hipStreamBeginCapture.stream; + impl->streamToCaptureCount[stream] = 0; + impl->streamToCapture[stream] = true; + impl->activeCaptureCount.fetch_add(1, std::memory_order_release); + break; + } + case ROCPROFILER_HIP_RUNTIME_API_ID_hipStreamEndCapture: { + auto stream = payload->args.hipStreamEndCapture.stream; + auto graph = *(payload->args.hipStreamEndCapture.pGraph); + uint32_t captured = impl->streamToCaptureCount.contain(stream) + ? 
impl->streamToCaptureCount[stream] + : 0; + impl->graphToNumInstances[graph] = captured; + impl->streamToCapture.erase(stream); + impl->streamToCaptureCount.erase(stream); + impl->activeCaptureCount.fetch_sub(1, std::memory_order_release); + break; + } + case ROCPROFILER_HIP_RUNTIME_API_ID_hipGraphInstantiateWithFlags: { + auto graph = payload->args.hipGraphInstantiateWithFlags.graph; + auto graphExec = *(payload->args.hipGraphInstantiateWithFlags.pGraphExec); + impl->graphExecToGraph[graphExec] = graph; + break; + } + case ROCPROFILER_HIP_RUNTIME_API_ID_hipGraphInstantiate: { + auto graph = payload->args.hipGraphInstantiate.graph; + auto graphExec = *(payload->args.hipGraphInstantiate.pGraphExec); + impl->graphExecToGraph[graphExec] = graph; + break; + } + case ROCPROFILER_HIP_RUNTIME_API_ID_hipGraphExecDestroy: { + auto graphExec = payload->args.hipGraphExecDestroy.graphExec; + impl->graphExecToGraph.erase(graphExec); + break; + } + case ROCPROFILER_HIP_RUNTIME_API_ID_hipGraphDestroy: { + auto graph = payload->args.hipGraphDestroy.graph; + impl->graphToNumInstances.erase(graph); + break; + } + default: + break; + } + + // Count kernel launches during graph capture. The atomic fast-check avoids + // acquiring the shared_mutex on streamToCapture for every kernel launch + // when no capture is active (the overwhelmingly common case). 
+ if (isKernelOp && + impl->activeCaptureCount.load(std::memory_order_acquire) > 0) { + hipStream_t stream = nullptr; + switch (operation) { + case ROCPROFILER_HIP_RUNTIME_API_ID_hipLaunchKernel: + stream = payload->args.hipLaunchKernel.stream; + break; + case ROCPROFILER_HIP_RUNTIME_API_ID_hipExtLaunchKernel: + stream = payload->args.hipExtLaunchKernel.stream; + break; + case ROCPROFILER_HIP_RUNTIME_API_ID_hipLaunchCooperativeKernel: + stream = payload->args.hipLaunchCooperativeKernel.stream; + break; + case ROCPROFILER_HIP_RUNTIME_API_ID_hipModuleLaunchKernel: + stream = payload->args.hipModuleLaunchKernel.stream; + break; + case ROCPROFILER_HIP_RUNTIME_API_ID_hipModuleLaunchCooperativeKernel: + stream = payload->args.hipModuleLaunchCooperativeKernel.stream; + break; + case ROCPROFILER_HIP_RUNTIME_API_ID_hipExtModuleLaunchKernel: + stream = payload->args.hipExtModuleLaunchKernel.stream; + break; + case ROCPROFILER_HIP_RUNTIME_API_ID_hipHccModuleLaunchKernel: + stream = payload->args.hipHccModuleLaunchKernel.stream; + break; + default: + break; + } + if (stream && impl->streamToCapture.contain(stream)) + impl->streamToCaptureCount[stream]++; + } + + if (isKernelOp) { + threadState.exitOp(); + profiler.correlation.submit(record.correlation_id.internal); + } +} + +// ---- ROCTx marker callback via rocprofiler-sdk ---- +// +// Prefer rocprofiler-sdk marker tracing for ROCTx events. Some PyTorch/ROCm +// environments load the legacy libroctx64 provider for torch.cuda.nvtx calls +// without making its symbols globally visible, so MARKER_CORE_API alone does +// not see those ranges. registerRoctxCallback below attaches to the loaded +// legacy provider when present. 
+ +void RocprofSDKProfiler::RocprofSDKProfilerPimpl::markerCallback( + rocprofiler_callback_tracing_record_t record, + rocprofiler_user_data_t *userData, void *arg) { + if (record.kind != ROCPROFILER_CALLBACK_TRACING_MARKER_CORE_API) + return; + if (record.phase != ROCPROFILER_CALLBACK_PHASE_ENTER) + return; + if (!getRuntimeState().nvtxEnabled.load(std::memory_order_relaxed)) + return; + + auto op = static_cast(record.operation); + if (op == ROCPROFILER_MARKER_CORE_API_ID_roctxRangePushA) { + auto *payload = + static_cast( + record.payload); + threadState.enterScope(payload->args.roctxRangePushA.message); + } else if (op == ROCPROFILER_MARKER_CORE_API_ID_roctxRangePop) { + threadState.exitScope(); + } +} + +// Legacy libroctx64.so callback — kept as fallback for environments where +// librocprofiler-sdk-roctx.so is not loaded (e.g. bare ROCm without TheRock). +void RocprofSDKProfiler::RocprofSDKProfilerPimpl::roctxCallback( + uint32_t operationId, void *data) { + auto *apiData = static_cast(data); + if (operationId == ROCTX_API_ID_roctxRangePushA) { + threadState.enterScope(apiData->args.roctxRangePushA.message); + } else if (operationId == ROCTX_API_ID_roctxRangePop) { + threadState.exitScope(); + } +} + +namespace { +int roctxTracerCallback(uint32_t /*domain*/, uint32_t operationId, void *data) { + RocprofSDKProfiler::RocprofSDKProfilerPimpl::roctxCallback(operationId, data); + return 0; +} + +void registerRoctxCallback(bool enable) { + // torch.cuda.nvtx may route through a locally loaded libroctx64.so. In that + // case dlsym(RTLD_DEFAULT, "roctxRegisterTracerCallback") does not find the + // callback registration entry point, but resolving it from the library handle + // does. + void *roctxLib = dlopen("libroctx64.so", RTLD_NOLOAD | RTLD_NOW); + if (!roctxLib) + return; + auto *fn = reinterpret_cast( + dlsym(roctxLib, "roctxRegisterTracerCallback")); + dlclose(roctxLib); + if (!fn) + return; + fn(enable ? 
&roctxTracerCallback : nullptr); +} +} // namespace + +// ---- Code object callback (kernel_id -> name mapping) ---- + +void RocprofSDKProfiler::RocprofSDKProfilerPimpl::codeObjectCallback( + rocprofiler_callback_tracing_record_t record, + rocprofiler_user_data_t *userData, void *arg) { + if (record.kind != ROCPROFILER_CALLBACK_TRACING_CODE_OBJECT || + record.operation != + ROCPROFILER_CODE_OBJECT_DEVICE_KERNEL_SYMBOL_REGISTER || + record.phase != ROCPROFILER_CALLBACK_PHASE_LOAD) { + return; + } + auto *impl = static_cast(arg); + if (!impl) + return; + auto *payload = static_cast< + rocprofiler_callback_tracing_code_object_kernel_symbol_register_data_t *>( + record.payload); + impl->setKernelName(payload->kernel_id, payload->kernel_name); +} + +// ---- Kernel dispatch buffer callback ---- + +void RocprofSDKProfiler::RocprofSDKProfilerPimpl::kernelBufferCallback( + rocprofiler_context_id_t context, rocprofiler_buffer_id_t buffer, + rocprofiler_record_header_t **headers, size_t numHeaders, void *userData, + uint64_t dropCount) { + if (dropCount > 0) { + std::cerr << "[PROTON] ROCProfiler-SDK dropped " << dropCount + << " kernel dispatch records" << std::endl; + } + auto &profiler = threadState.profiler; + auto *impl = static_cast(profiler.pImpl.get()); + auto &correlation = profiler.correlation; + + uint64_t maxCorrelationId = 0; + std::map> dataPhases; + + for (size_t i = 0; i < numHeaders; ++i) { + auto *header = headers[i]; + if (header->category != ROCPROFILER_BUFFER_CATEGORY_TRACING || + header->kind != ROCPROFILER_BUFFER_TRACING_KERNEL_DISPATCH) { + continue; + } + auto *record = + static_cast( + header->payload); + maxCorrelationId = + std::max(maxCorrelationId, record->correlation_id.internal); + auto kernelName = impl->getKernelName(record->dispatch_info.kernel_id); + uint64_t streamId = + static_cast(record->dispatch_info.queue_id.handle); + impl->corrIdToStreamId.withRead( + record->correlation_id.internal, + [&](const uint64_t &sid) { streamId = sid; }); 
+ processKernelRecord(profiler, correlation.corrIdToExternId, + correlation.externIdToState, impl->corrIdToIsHipGraph, + dataPhases, kernelName, record, streamId); + impl->corrIdToStreamId.erase(record->correlation_id.internal); + } + profiler.flushDataPhases(dataPhases, profiler.pendingGraphPool.get()); + if (maxCorrelationId > 0) { + correlation.complete(maxCorrelationId); + } +} + +// ---- SDK tool init / fini (called by rocprofiler_force_configure) ---- + +namespace { + +int protonToolInit(rocprofiler_client_finalize_t finiFunc, void *toolData) { + auto *state = static_cast(toolData); + state->finalizeFunc = finiFunc; + + // Context 1: lightweight, always-active context for code object tracking. + // Captures kernel_id -> name mappings as kernels are compiled. + rocprofiler::createContext(&state->codeObjectContext); + + const rocprofiler_tracing_operation_t codeObjectOps[] = { + ROCPROFILER_CODE_OBJECT_DEVICE_KERNEL_SYMBOL_REGISTER}; + rocprofiler::configureCallbackTracingService( + state->codeObjectContext, ROCPROFILER_CALLBACK_TRACING_CODE_OBJECT, + codeObjectOps, 1, + &RocprofSDKProfiler::RocprofSDKProfilerPimpl::codeObjectCallback, + static_cast(state->pimpl)); + + // Context 2: on-demand profiling context for HIP callback tracing and + // kernel dispatch buffer tracing. Started/stopped in doStart()/doStop(). + // Registering BUFFER_TRACING_KERNEL_DISPATCH here causes + // enable_queue_intercept() to install HSA queue hooks at force_configure + // time, even though the context is not yet active. + rocprofiler::createContext(&state->profilingContext); + + // Subscribe only to the HIP operations Proton needs: kernel launches, + // graph capture/instantiate/destroy. Passing nullptr/0 would subscribe to + // all ~519 HIP runtime APIs, causing the SDK to construct correlation IDs + // and invoke our callback for every hipMalloc, hipMemcpy, etc. 
+ constexpr rocprofiler_tracing_operation_t kTracedHipOps[] = { + // Kernel launches (ENTER: correlation tracking, EXIT: capture counting) + ROCPROFILER_HIP_RUNTIME_API_ID_hipLaunchKernel, + ROCPROFILER_HIP_RUNTIME_API_ID_hipExtLaunchKernel, + ROCPROFILER_HIP_RUNTIME_API_ID_hipExtLaunchMultiKernelMultiDevice, + ROCPROFILER_HIP_RUNTIME_API_ID_hipExtModuleLaunchKernel, + ROCPROFILER_HIP_RUNTIME_API_ID_hipHccModuleLaunchKernel, + ROCPROFILER_HIP_RUNTIME_API_ID_hipLaunchCooperativeKernel, + ROCPROFILER_HIP_RUNTIME_API_ID_hipLaunchCooperativeKernelMultiDevice, + ROCPROFILER_HIP_RUNTIME_API_ID_hipModuleLaunchKernel, + ROCPROFILER_HIP_RUNTIME_API_ID_hipModuleLaunchCooperativeKernel, + ROCPROFILER_HIP_RUNTIME_API_ID_hipModuleLaunchCooperativeKernelMultiDevice, + ROCPROFILER_HIP_RUNTIME_API_ID_hipGraphLaunch, + // Graph capture (EXIT only) + ROCPROFILER_HIP_RUNTIME_API_ID_hipStreamBeginCapture, + ROCPROFILER_HIP_RUNTIME_API_ID_hipStreamEndCapture, + // Graph instantiate (EXIT only) + ROCPROFILER_HIP_RUNTIME_API_ID_hipGraphInstantiate, + ROCPROFILER_HIP_RUNTIME_API_ID_hipGraphInstantiateWithFlags, + // Graph cleanup (EXIT only) + ROCPROFILER_HIP_RUNTIME_API_ID_hipGraphExecDestroy, + ROCPROFILER_HIP_RUNTIME_API_ID_hipGraphDestroy, + }; + + rocprofiler::configureCallbackTracingService( + state->profilingContext, ROCPROFILER_CALLBACK_TRACING_HIP_RUNTIME_API, + kTracedHipOps, std::size(kTracedHipOps), + &RocprofSDKProfiler::RocprofSDKProfilerPimpl::hipRuntimeCallback, + nullptr); + + // Marker tracing: always configure MARKER_CORE_API so we intercept roctx + // calls that go through librocprofiler-sdk-roctx.so (TheRock/torch + // environments where the SDK's roctx interposes the global symbol). + // This is configured unconditionally because force_configure may run before + // torch loads the SDK's roctx library, and we can't add tracing services + // after startContext. If the SDK's roctx isn't loaded, these callbacks + // simply never fire. 
The legacy libroctx64.so callback registration in + // doStart()/doStop() handles environments where only libroctx64.so is used. + { + constexpr rocprofiler_tracing_operation_t kMarkerOps[] = { + ROCPROFILER_MARKER_CORE_API_ID_roctxRangePushA, + ROCPROFILER_MARKER_CORE_API_ID_roctxRangePop, + }; + rocprofiler::configureCallbackTracingService( + state->profilingContext, ROCPROFILER_CALLBACK_TRACING_MARKER_CORE_API, + kMarkerOps, std::size(kMarkerOps), + &RocprofSDKProfiler::RocprofSDKProfilerPimpl::markerCallback, nullptr); + } + + // Flush the buffer when it reaches 87.5% capacity, leaving headroom for + // in-flight records while the callback drains the buffer. + size_t watermark = BufferSize - (BufferSize / 8); + rocprofiler::createBuffer( + state->profilingContext, BufferSize, watermark, + ROCPROFILER_BUFFER_POLICY_LOSSLESS, + &RocprofSDKProfiler::RocprofSDKProfilerPimpl::kernelBufferCallback, + nullptr, &state->kernelBuffer); + + rocprofiler::configureBufferTracingService( + state->profilingContext, ROCPROFILER_BUFFER_TRACING_KERNEL_DISPATCH, + nullptr, 0, state->kernelBuffer); + + rocprofiler::createCallbackThread(&state->callbackThread); + rocprofiler::assignCallbackThread(state->kernelBuffer, + state->callbackThread); + + AgentIdMapper::instance().initialize(); + + // Start the code object context now so the upcoming + // invoke_register_propagation() replay of already-loaded code objects + // triggers our callback while it's active. 
+ rocprofiler::startContext(state->codeObjectContext); + state->codeObjectStarted = true; + + state->configured = true; + return 0; +} + +void protonToolFini(void *toolData) { + auto *state = static_cast(toolData); + { + std::lock_guard lock(state->mutex); + if (state->profilingStarted) { + rocprofiler::stopContext(state->profilingContext); + state->profilingStarted = false; + } + if (state->codeObjectStarted) { + rocprofiler::stopContext(state->codeObjectContext); + state->codeObjectStarted = false; + } + } + rocprofiler::flushBuffer(state->kernelBuffer); + if (state->finalizeFunc && state->clientId) { + state->finalizeFunc(*state->clientId); + } +} + +rocprofiler_tool_configure_result_t * +protonConfigure(uint32_t version, const char *runtimeVersion, uint32_t priority, + rocprofiler_client_id_t *id) { + auto &state = getRuntimeState(); + id->name = "ProtonRocprofSDK"; + state.clientId = id; + static rocprofiler_tool_configure_result_t config{ + sizeof(rocprofiler_tool_configure_result_t), &protonToolInit, + &protonToolFini, static_cast(&state)}; + return &config; +} + +} // namespace + +// ---- Profiler lifecycle ---- + +void RocprofSDKProfiler::RocprofSDKProfilerPimpl::doStart() { + auto &state = getRuntimeState(); + std::lock_guard lock(state.mutex); + if (!state.profilingStarted) { + rocprofiler::startContext(state.profilingContext); + state.profilingStarted = true; + } + bool nvtx = getBoolEnv("TRITON_ENABLE_NVTX", true); + state.nvtxEnabled.store(nvtx, std::memory_order_relaxed); + if (nvtx) + registerRoctxCallback(true); +} + +void RocprofSDKProfiler::RocprofSDKProfilerPimpl::doFlush() { + auto &state = getRuntimeState(); + std::ignore = hip::deviceSynchronize(); + profiler.correlation.flush( + /*maxRetries=*/100, /*sleepUs=*/10, + [&state]() { rocprofiler::flushBuffer(state.kernelBuffer); }); +} + +void RocprofSDKProfiler::RocprofSDKProfilerPimpl::doStop() { + auto &state = getRuntimeState(); + state.nvtxEnabled.store(false, std::memory_order_relaxed); + 
registerRoctxCallback(false); + // Keep the profiling context running. rocprofiler-sdk does not reliably + // re-intercept HIP runtime API calls after a stopContext→startContext + // cycle on the same context. The correlation ID mechanism ensures that + // kernel dispatch records without a matching active session are discarded. +} + +RocprofSDKProfiler::RocprofSDKProfiler() { + pImpl = std::make_unique(*this); + auto &state = getRuntimeState(); + state.pimpl = static_cast(pImpl.get()); + // Configure rocprofiler-sdk as soon as this singleton is constructed. + // Deferring until doStart() is unsafe: any code that fully initializes HSA + // beforehand (e.g. triton's HIP driver query at pytest collection time, + // or a torch import chain) causes rocprofiler-sdk 1.2.0 to silently skip + // kernel-dispatch buffer tracing installation on already-existing queues, + // producing an empty dispatch buffer and no per-kernel timing data. + // Construction of this singleton is triggered at libproton.so load time + // via the __attribute__((constructor)) hook below, so force_configure + // lands before any user code touches the HIP/HSA runtimes. + if (!state.configured) { + rocprofiler::forceConfigure(&protonConfigure); + } +} + +RocprofSDKProfiler::~RocprofSDKProfiler() = default; + +namespace { +// Runs during dlopen of libproton.so (i.e. `import triton.profiler._C`). +// Touches the singleton so its constructor — which calls +// rocprofiler_force_configure — runs before any Python code executes. +// Wrapped in try/catch so non-ROCm environments (where +// librocprofiler-sdk.so cannot be dlopen'd) continue to import cleanly; +// a subsequent attempt to start a "rocprofiler" session will surface the +// error through the normal lazy-dispatch path. +__attribute__((constructor)) void protonRocprofSDKLoadHook() { + try { + (void)RocprofSDKProfiler::instance(); + } catch (...) { + // Intentionally swallowed: non-ROCm or rocprofiler-sdk unavailable. 
+ } +} +} // namespace + +void RocprofSDKProfiler::doSetMode( + const std::vector &modeAndOptions) { + auto mode = modeAndOptions.empty() ? std::string() : modeAndOptions[0]; + if (proton::toLower(mode) == "periodic_flushing") { + detail::setPeriodicFlushingMode(periodicFlushingEnabled, + periodicFlushingFormat, modeAndOptions, + "RocprofSDKProfiler"); + } else if (!mode.empty()) { + throw std::invalid_argument( + "[PROTON] RocprofSDKProfiler: unsupported mode: " + mode); + } +} + +} // namespace proton diff --git a/third_party/proton/csrc/lib/Session/Session.cpp b/third_party/proton/csrc/lib/Session/Session.cpp index f999b45e50ed..949c18e26a05 100644 --- a/third_party/proton/csrc/lib/Session/Session.cpp +++ b/third_party/proton/csrc/lib/Session/Session.cpp @@ -5,6 +5,7 @@ #include "Data/TreeData.h" #include "Profiler/Cupti/CuptiProfiler.h" #include "Profiler/Instrumentation/InstrumentationProfiler.h" +#include "Profiler/RocprofSDK/RocprofSDKProfiler.h" #include "Profiler/Roctracer/RoctracerProfiler.h" #include "Utility/Errors.h" #include "Utility/String.h" @@ -16,6 +17,8 @@ namespace { Profiler *makeProfiler(const std::string &name) { if (proton::toLower(name) == "cupti") { return &CuptiProfiler::instance(); + } else if (proton::toLower(name) == "rocprofiler") { + return &RocprofSDKProfiler::instance(); } else if (proton::toLower(name) == "roctracer") { return &RoctracerProfiler::instance(); } else if (proton::toLower(name) == "instrumentation") { diff --git a/third_party/proton/proton/__init__.py b/third_party/proton/proton/__init__.py index 2ea3e401c9b6..e02317fd5324 100644 --- a/third_party/proton/proton/__init__.py +++ b/third_party/proton/proton/__init__.py @@ -1,4 +1,24 @@ # ruff: noqa + + +# When running in a TheRock virtual environment, ROCm libraries live under +# _rocm_sdk_core/lib/ which isn't on LD_LIBRARY_PATH. Point the C++ backend +# at the correct directory so dlopen() can find librocprofiler-sdk.so et al. 
+def _ensure_rocm_lib_env(): + import os + if os.environ.get("TRITON_ROCPROFILER_SDK_LIB_PATH"): + return + try: + import _rocm_sdk_core + lib_dir = os.path.join(os.path.dirname(_rocm_sdk_core.__file__), "lib") + if os.path.isdir(lib_dir): + os.environ["TRITON_ROCPROFILER_SDK_LIB_PATH"] = lib_dir + except ImportError: + pass + + +_ensure_rocm_lib_env() + from .scope import scope, cpu_timed_scope, enter_scope, exit_scope from .state import state, enter_state, exit_state, metadata_state from .profile import ( diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index 5288187fcb10..7e90062fb53f 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -16,10 +16,9 @@ def _select_backend() -> str: backend = triton.runtime.driver.active.get_current_target().backend if backend == "cuda": return "cupti" - elif backend == "hip": - return "roctracer" - else: - raise ValueError("No backend is available for the current target.") + if backend == "hip": + return "rocprofiler" + raise ValueError("No backend is available for the current target.") def _get_mode_str(backend: str, mode: Optional[Union[str, BaseMode]]) -> str: @@ -30,7 +29,7 @@ def _get_mode_str(backend: str, mode: Optional[Union[str, BaseMode]]) -> str: def _check_env(backend: str) -> None: - if backend == "roctracer": + if backend in ("rocprofiler", "roctracer"): hip_device_envs = ["HIP_VISIBLE_DEVICES", "CUDA_VISIBLE_DEVICES"] for env in hip_device_envs: if getenv(env, None) is not None: @@ -85,13 +84,16 @@ def start( Available options are ["tree", "trace"]. Defaults to "tree". backend (str, optional): The backend to use for profiling. - Available options are [None, "cupti", "roctracer", "instrumentation"]. + Available options are [None, "cupti", "rocprofiler", "roctracer", "instrumentation"]. Defaults to None, which automatically selects the backend matching the current active runtime. 
+ On AMD GPUs, "rocprofiler" is preferred and will fall back to "roctracer" if + rocprofiler-sdk is not available. mode (Union[str, BaseMode], optional): The "mode" to use for profiling, which is specific to the backend. Can be a string or an instance of BaseMode (or any subclass thereof). Defaults to None. For "cupti", available options are [None, "pcsampling", "periodic_flushing"]. - For "roctracer", available options are ["periodic_flushing"]. + For "rocprofiler", available options are [None, "periodic_flushing"]. + For "roctracer", available options are [None, "periodic_flushing"]. For "instrumentation", available options are [None]. Each mode has a set of control knobs following with the mode name. For example, "periodic_flushing" mode has a knob: diff --git a/third_party/proton/proton/proton.py b/third_party/proton/proton/proton.py index a7689288da0c..c832006c4918 100644 --- a/third_party/proton/proton/proton.py +++ b/third_party/proton/proton/proton.py @@ -16,7 +16,7 @@ def parse_arguments(): """, formatter_class=argparse.RawTextHelpFormatter) parser.add_argument("-n", "--name", type=str, help="Name of the profiling session") parser.add_argument("-b", "--backend", type=str, help="Profiling backend", default=None, - choices=["cupti", "roctracer", "instrumentation"]) + choices=["cupti", "rocprofiler", "instrumentation"]) parser.add_argument("-c", "--context", type=str, help="Profiling context", default="shadow", choices=["shadow", "python"]) parser.add_argument("-m", "--mode", type=str, help="Profiling mode", default=None) diff --git a/third_party/proton/test/test_api.py b/third_party/proton/test/test_api.py index 67dcd3b604a7..e19d627bafe3 100644 --- a/third_party/proton/test/test_api.py +++ b/third_party/proton/test/test_api.py @@ -96,7 +96,7 @@ def test_profile_mode(tmp_path: pathlib.Path): try: proton.start(str(temp_file0.with_suffix("")), mode="pcsampling") except Exception as e: - assert "RoctracerProfiler: unsupported mode: pcsampling" in str(e) + 
assert "unsupported mode: pcsampling" in str(e) finally: proton.finalize() else: diff --git a/third_party/proton/test/test_cmd.py b/third_party/proton/test/test_cmd.py index 7e1d438d3b99..ba64e720322a 100644 --- a/third_party/proton/test/test_cmd.py +++ b/third_party/proton/test/test_cmd.py @@ -2,6 +2,7 @@ import subprocess import json import pathlib +import sys def test_help(): @@ -18,7 +19,7 @@ def test_exec(mode, tmp_path: pathlib.Path): if mode == "script": subprocess.check_call(["proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) elif mode == "python": - subprocess.check_call(["python3", "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], + subprocess.check_call([sys.executable, "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) elif mode == "pytest": subprocess.check_call(["proton", "-n", name, "pytest", "-k", "test_main", helper_file], diff --git a/third_party/proton/tutorials/dynamic-net.py b/third_party/proton/tutorials/dynamic-net.py index 8a933d200f33..63f11a57b960 100644 --- a/third_party/proton/tutorials/dynamic-net.py +++ b/third_party/proton/tutorials/dynamic-net.py @@ -85,7 +85,7 @@ def run(): argparser.add_argument("--profile", action="store_true") argparser.add_argument("--engine", default="torch", choices=["torch", "torchinductor"]) argparser.add_argument("--context", default="shadow", choices=["shadow", "python"]) -argparser.add_argument("--backend", default=None, choices=["cupti", "roctracer"]) +argparser.add_argument("--backend", default=None, choices=["cupti", "rocprofiler"]) argparser.add_argument("--mode", default=None) args = argparser.parse_args()