diff --git a/.azure-pipelines/ut.yml b/.azure-pipelines/ut.yml index 78b679e8d..40a648809 100644 --- a/.azure-pipelines/ut.yml +++ b/.azure-pipelines/ut.yml @@ -79,3 +79,85 @@ jobs: export PATH=/usr/local/mpi/bin:$PATH mpirun -tag-output -x MSCCLPP_HOME=$(System.DefaultWorkingDirectory) -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x workingDirectory: '$(System.DefaultWorkingDirectory)' + +- job: UnitTestWithNpKit + timeoutInMinutes: 30 + pool: + name: mscclpp + strategy: + matrix: + cuda11: + containerImage: ghcr.io/microsoft/mscclpp/mscclpp:base-dev-cuda11.8 + cuda12: + containerImage: ghcr.io/microsoft/mscclpp/mscclpp:base-dev-cuda12.2 + + container: + image: $[ variables['containerImage'] ] + options: --privileged --ipc=host --gpus=all --ulimit memlock=-1:-1 + + steps: + - task: Bash@3 + name: Build + displayName: Build + inputs: + targetType: 'inline' + script: | + mkdir build && cd build + cmake -DCMAKE_BUILD_TYPE=Release -DNPKIT_FLAGS="-DENABLE_NPKIT -DENABLE_NPKIT_EVENT_TIME_SYNC_CPU -DENABLE_NPKIT_EVENT_TIME_SYNC_GPU -DENABLE_NPKIT_EVENT_EXECUTOR_INIT_ENTRY -DENABLE_NPKIT_EVENT_EXECUTOR_INIT_EXIT -DENABLE_NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY -DENABLE_NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT" .. + make -j + workingDirectory: '$(System.DefaultWorkingDirectory)' + + - task: Bash@3 + name: LockGPUClock + displayName: Lock GPU clock frequency + inputs: + targetType: 'inline' + script: | + sudo nvidia-smi -pm 1 + for i in $(seq 0 $(( $(nvidia-smi -L | wc -l) - 1 ))); do + sudo nvidia-smi -ac $(nvidia-smi --query-gpu=clocks.max.memory,clocks.max.sm --format=csv,noheader,nounits -i $i | sed 's/\ //') -i $i + done + workingDirectory: '$(System.DefaultWorkingDirectory)' + + - task: Bash@3 + name: MpUnitTests + displayName: Run mscclpp multi-process unit tests + inputs: + targetType: 'inline' + script: | + set -e + rm -rf ./npkit_dump && mkdir ./npkit_dump && rm -rf ./npkit_output && mkdir ./npkit_output + export PATH=/usr/local/mpi/bin:$PATH + export NPKIT_DUMP_DIR=./npkit_dump + mpirun -tag-output -np 2 ./build/test/mp_unit_tests --gtest_filter="ExecutorTest.TwoNodesAllreduce" + python3 ./tools/npkit/npkit_trace_generator.py --npkit_dump_dir=./npkit_dump --npkit_event_header_path=./include/mscclpp/npkit/npkit_event.hpp --output_dir=./npkit_output + grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_SIGNAL_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_WAIT_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_READ_REDUCE_COPY_SEND_ENTRY ./npkit_output/npkit_event_trace.json + workingDirectory: '$(System.DefaultWorkingDirectory)' + + - task: Bash@3 + name: PyTests + displayName: Run pytests + inputs: + targetType: 'inline' + script: | + set -e + rm -rf ./npkit_dump && mkdir ./npkit_dump && rm -rf ./npkit_output && mkdir ./npkit_output + export PATH=/usr/local/mpi/bin:$PATH + export NPKIT_DUMP_DIR=./npkit_dump + mpirun -tag-output -x MSCCLPP_HOME=$(System.DefaultWorkingDirectory) -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x -k 'test_executor[allreduce.json' + python3 ./tools/npkit/npkit_trace_generator.py --npkit_dump_dir=./npkit_dump --npkit_event_header_path=./include/mscclpp/npkit/npkit_event.hpp --output_dir=./npkit_output + grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_SIGNAL_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_WAIT_ENTRY ./npkit_output/npkit_event_trace.json + grep -q 
NPKIT_EVENT_EXECUTOR_READ_REDUCE_COPY_SEND_ENTRY ./npkit_output/npkit_event_trace.json + rm -rf ./npkit_dump && mkdir ./npkit_dump && rm -rf ./npkit_output && mkdir ./npkit_output + mpirun -tag-output -x MSCCLPP_HOME=$(System.DefaultWorkingDirectory) -np 8 python3 -m pytest ./python/test/test_mscclpp.py -x -k 'test_executor[allreduce_packet.json' + python3 ./tools/npkit/npkit_trace_generator.py --npkit_dump_dir=./npkit_dump --npkit_event_header_path=./include/mscclpp/npkit/npkit_event.hpp --output_dir=./npkit_output + grep -q NPKIT_EVENT_EXECUTOR_INIT_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_COPY_PACKET_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_PUT_PACKET_ENTRY ./npkit_output/npkit_event_trace.json + grep -q NPKIT_EVENT_EXECUTOR_REDUCE_SEND_PACKET_ENTRY ./npkit_output/npkit_event_trace.json + workingDirectory: '$(System.DefaultWorkingDirectory)' diff --git a/CMakeLists.txt b/CMakeLists.txt index 31525f9c9..58918eec5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,7 +15,6 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) # Options option(ENABLE_TRACE "Enable tracing" OFF) -option(USE_NPKIT "Use NPKIT" ON) option(BUILD_TESTS "Build tests" ON) option(BUILD_PYTHON_BINDINGS "Build Python bindings" ON) option(USE_CUDA "Use NVIDIA/CUDA." OFF) @@ -119,8 +118,8 @@ endif() if(ENABLE_TRACE) target_compile_definitions(mscclpp_obj PRIVATE ENABLE_TRACE) endif() -if(USE_NPKIT) - target_compile_definitions(mscclpp_obj PRIVATE ENABLE_NPKIT) +if(NPKIT_FLAGS) + target_compile_definitions(mscclpp_obj PRIVATE ${NPKIT_FLAGS}) endif() # libmscclpp diff --git a/include/mscclpp/npkit/npkit.hpp b/include/mscclpp/npkit/npkit.hpp new file mode 100644 index 000000000..d2f98a7c5 --- /dev/null +++ b/include/mscclpp/npkit/npkit.hpp @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
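Note on the new CI coverage above: the NPKit jobs build with -DNPKIT_FLAGS (replacing the old USE_NPKIT option), point NPKIT_DUMP_DIR at a scratch directory, run one executor workload under mpirun, convert the dump with tools/npkit/npkit_trace_generator.py, and grep npkit_event_trace.json for the expected executor events. The same flow can be reproduced locally; the sketch below is illustrative only (the run_npkit_smoke_test helper is not part of the repo, and it assumes a single 8-GPU node with MPI on PATH), but every command it launches is taken from the pipeline.

import json
import os
import shutil
import subprocess

def run_npkit_smoke_test(repo_root):
    """Illustrative local re-run of the CI NPKit smoke test."""
    dump_dir = os.path.join(repo_root, "npkit_dump")
    out_dir = os.path.join(repo_root, "npkit_output")
    for d in (dump_dir, out_dir):
        shutil.rmtree(d, ignore_errors=True)
        os.makedirs(d)

    env = dict(os.environ, NPKIT_DUMP_DIR=dump_dir, MSCCLPP_HOME=repo_root)
    # Same pytest invocation as the pipeline's PyTests step.
    subprocess.run(
        ["mpirun", "-tag-output", "-x", "MSCCLPP_HOME", "-x", "NPKIT_DUMP_DIR", "-np", "8",
         "python3", "-m", "pytest", "./python/test/test_mscclpp.py", "-x",
         "-k", "test_executor[allreduce.json"],
        cwd=repo_root, env=env, check=True)

    # Convert the raw dump into a Chrome-trace JSON, exactly as the CI step does.
    subprocess.run(
        ["python3", "./tools/npkit/npkit_trace_generator.py",
         "--npkit_dump_dir", dump_dir,
         "--npkit_event_header_path", "./include/mscclpp/npkit/npkit_event.hpp",
         "--output_dir", out_dir],
        cwd=repo_root, check=True)

    with open(os.path.join(out_dir, "npkit_event_trace.json")) as f:
        trace = json.load(f)
    events = trace["traceEvents"] if isinstance(trace, dict) else trace
    assert any(e.get("name", "").startswith("NPKIT_EVENT_EXECUTOR_INIT") for e in events)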
+ +#ifndef NPKIT_H_ +#define NPKIT_H_ + +#include +#include +#include +#include +#include +#include +#include + +#if defined(__HIP_PLATFORM_AMD__) +#define NPKIT_GET_GPU_TIMESTAMP wall_clock64 +#else +#define NPKIT_GET_GPU_TIMESTAMP clock64 +#endif + +#define NPKIT_SHM_NUM_EVENTS 64 + +class NpKit { + public: + static const uint64_t kNumGpuEventBuffers = 1024; + + static const uint64_t kNumCpuEventBuffers = 64; + + static void Init(int rank); + + static void Dump(const std::string& dump_dir); + + static void Shutdown(); + + static NpKitEventCollectContext* GetGpuEventCollectContexts(); + +#if defined(MSCCLPP_DEVICE_COMPILE) + static MSCCLPP_DEVICE_INLINE void CollectGpuEventShm(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp, + NpKitEvent* event_buffer, uint64_t* event_buffer_head) { + if (*event_buffer_head < NPKIT_SHM_NUM_EVENTS) { + if (threadIdx.x == 0) { + NpKitEvent& event = event_buffer[*event_buffer_head]; + event.fields.type = type; + event.fields.size = size; + event.fields.rsvd = rsvd; + event.fields.timestamp = timestamp; + } + (*event_buffer_head)++; + } + } + + static MSCCLPP_DEVICE_INLINE void StoreGpuEventShm(NpKitEventCollectContext* npKitEventCollectContexts, + NpKitEvent* event_buffer, uint64_t event_buffer_head) { +#if defined(MSCCLPP_DEVICE_HIP) + __synclds(); +#else // !defined(MSCCLPP_DEVICE_HIP) + __syncthreads(); +#endif // !defined(MSCCLPP_DEVICE_HIP) + NpKitEventCollectContext* npKitCtx = npKitEventCollectContexts + blockIdx.x; + NpKitEvent* global_event_buffer = npKitCtx->event_buffer; + uint64_t global_event_buffer_head = npKitCtx->event_buffer_head; + for (size_t i = threadIdx.x; i < event_buffer_head * sizeof(NpKitEvent) / sizeof(int4); i += blockDim.x) { + ((int4*)(global_event_buffer + global_event_buffer_head))[i] = ((int4*)event_buffer)[i]; + } + if (threadIdx.x == 0) { + npKitCtx->event_buffer_head += event_buffer_head; + } + } +#endif + + static void CollectCpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp, int channel_id); + + static uint64_t* GetCpuTimestamp(); + + private: + static void CpuTimestampUpdateThread(); + + // 64K * 1024 * 16B = 1GB per GPU + static const uint64_t kMaxNumGpuEventsPerBuffer = 1ULL << 16; + + // 64K * 2 (send/recv) * (1024/64) = 2M, 2M * 64 * 16B = 2GB per CPU + static const uint64_t kMaxNumCpuEventsPerBuffer = 1ULL << 21; + + static std::vector> gpu_event_buffers_; + static std::vector> cpu_event_buffers_; + + static mscclpp::UniqueCudaPtr gpu_collect_contexts_; + static std::unique_ptr cpu_collect_contexts_; + + static uint64_t rank_; + + static mscclpp::UniqueCudaHostPtr cpu_timestamp_; + static std::unique_ptr cpu_timestamp_update_thread_; + static volatile bool cpu_timestamp_update_thread_should_stop_; +}; + +#endif diff --git a/include/mscclpp/npkit/npkit_event.hpp b/include/mscclpp/npkit/npkit_event.hpp new file mode 100644 index 000000000..da0206c0f --- /dev/null +++ b/include/mscclpp/npkit/npkit_event.hpp @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
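The buffer-size comments above budget 16 bytes per event (64 K events x 1024 GPU buffers = 1 GB, 2 M events x 64 CPU buffers = 2 GB). For reference, a raw buffer written by NpKit::Dump can be decoded with a few lines of Python; the sketch below assumes the usual NPKit bit layout (8-bit type, 32-bit size, and 24-bit reserved field packed into the first 64-bit word, followed by a 64-bit timestamp) -- npkit_struct.hpp remains the authoritative definition, and the gpu_events_rank_0_buf_0 file name is only an example.

import struct

def decode_npkit_events(path):
    """Decode a raw NPKit event buffer dumped by NpKit::Dump (16-byte records, layout assumed)."""
    with open(path, "rb") as f:
        data = f.read()
    events = []
    for off in range(0, len(data) - 15, 16):
        word0, timestamp = struct.unpack_from("<QQ", data, off)
        events.append({
            "type": word0 & 0xFF,               # 8-bit event id, e.g. NPKIT_EVENT_EXECUTOR_INIT_ENTRY
            "size": (word0 >> 8) & 0xFFFFFFFF,  # 32-bit payload size in bytes
            "rsvd": (word0 >> 40) & 0xFFFFFF,   # 24-bit reserved field
            "timestamp": timestamp,             # GPU cycles or CPU clock ticks, depending on the buffer
        })
    return events

# e.g. decode_npkit_events("./npkit_dump/gpu_events_rank_0_buf_0")  # file name assumed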
+ +#ifndef NPKIT_EVENT_H_ +#define NPKIT_EVENT_H_ + +#define NPKIT_EVENT_INVALID 0x0 + +#define NPKIT_EVENT_TIME_SYNC_GPU 0x1 +#define NPKIT_EVENT_TIME_SYNC_CPU 0x2 + +#define NPKIT_EVENT_EXECUTOR_INIT_ENTRY 0x3 +#define NPKIT_EVENT_EXECUTOR_INIT_EXIT 0x4 + +#define NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY 0x5 +#define NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT 0x15 + +#endif diff --git a/src/npkit/npkit_struct.h b/include/mscclpp/npkit/npkit_struct.hpp similarity index 98% rename from src/npkit/npkit_struct.h rename to include/mscclpp/npkit/npkit_struct.hpp index 62b417f24..44de35357 100644 --- a/src/npkit/npkit_struct.h +++ b/include/mscclpp/npkit/npkit_struct.hpp @@ -25,4 +25,4 @@ struct NpKitEventCollectContext { #pragma pack(pop) -#endif \ No newline at end of file +#endif diff --git a/python/mscclpp/__init__.py b/python/mscclpp/__init__.py index 0acc55fc5..c9df30cf1 100644 --- a/python/mscclpp/__init__.py +++ b/python/mscclpp/__init__.py @@ -25,6 +25,7 @@ PacketType, version, is_nvls_supported, + npkit, ) __version__ = version() diff --git a/python/mscclpp/core_py.cpp b/python/mscclpp/core_py.cpp index 8dc9df57b..a44256a0d 100644 --- a/python/mscclpp/core_py.cpp +++ b/python/mscclpp/core_py.cpp @@ -22,6 +22,7 @@ extern void register_utils(nb::module_& m); extern void register_numa(nb::module_& m); extern void register_nvls(nb::module_& m); extern void register_executor(nb::module_& m); +extern void register_npkit(nb::module_& m); template void def_nonblocking_future(nb::handle& m, const std::string& typestr) { @@ -189,4 +190,5 @@ NB_MODULE(_mscclpp, m) { register_numa(m); register_nvls(m); register_executor(m); + register_npkit(m); } diff --git a/python/mscclpp/npkit_py.cpp b/python/mscclpp/npkit_py.cpp new file mode 100644 index 000000000..0557b72d8 --- /dev/null +++ b/python/mscclpp/npkit_py.cpp @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
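NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY/EXIT above are base values, not events in their own right: the execution kernel adds the numeric operation type to the base when it records an event, and npkit_trace_generator.py expands the bases back into per-operation names using the same ordered table. A sketch of that id space (the operation order is copied from the generator's table later in this diff; keeping it in sync with the executor's OperationType enum is what makes the expansion valid):

NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY = 0x5
NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT = 0x15

# Order must match the executor's OperationType enum and the table in
# tools/npkit/npkit_trace_generator.py.
EXECUTOR_OPS = [
    "BARRIER", "PUT", "PUT_PACKET", "GET", "COPY", "COPY_PACKET", "SIGNAL", "WAIT",
    "FLUSH", "REDUCE", "REDUCE_PACKET", "REDUCE_SEND", "REDUCE_SEND_PACKET",
    "READ_REDUCE_COPY", "READ_REDUCE_COPY_SEND",
]

def executor_event_ids(op_name):
    """Return the (ENTRY, EXIT) NPKit event ids recorded for one executor operation."""
    offset = EXECUTOR_OPS.index(op_name)
    return (NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY + offset,
            NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT + offset)

# executor_event_ids("SIGNAL") == (0xB, 0x1B): after expansion these show up in the trace
# as NPKIT_EVENT_EXECUTOR_SIGNAL_ENTRY/EXIT, which is what the CI greps for.

The exit base 0x15 sits 16 ids above the entry base 0x5, so up to 16 operation types fit before the two ranges would collide; the current table uses 15 of them.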
+ +#include +#include + +#include + +namespace nb = nanobind; + +void register_npkit(nb::module_ &m) { + nb::module_ sub_m = m.def_submodule("npkit", "NPKit functions"); + sub_m.def("init", &NpKit::Init); + sub_m.def("dump", &NpKit::Dump); + sub_m.def("shutdown", &NpKit::Shutdown); +} diff --git a/python/test/executor_test.py b/python/test/executor_test.py index d744a4c1a..3a0bd2d74 100644 --- a/python/test/executor_test.py +++ b/python/test/executor_test.py @@ -7,6 +7,7 @@ Executor, ExecutionPlan, PacketType, + npkit, ) import mscclpp.comm as mscclpp_comm @@ -87,6 +88,9 @@ def main( mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD) cp.cuda.Device(mscclpp_group.my_rank % mscclpp_group.nranks_per_node).use() executor = Executor(mscclpp_group.communicator) + npkit_dump_dir = os.getenv("NPKIT_DUMP_DIR") + if npkit_dump_dir is not None: + npkit.init(mscclpp_group.my_rank) execution_plan = ExecutionPlan(execution_paln_name, execution_plan_path) cp.random.seed(seed) @@ -119,6 +123,9 @@ def main( mscclpp_group.barrier() execution_time = bench_time(100, 10, executor_func) + if npkit_dump_dir is not None: + npkit.dump(npkit_dump_dir) + npkit.shutdown() print( f"Rank: {MPI.COMM_WORLD.rank} Execution time: {execution_time} us, " f"data size: {sendbuf.nbytes} bytes data type: {dtype().dtype.name} " diff --git a/python/test/test_mscclpp.py b/python/test/test_mscclpp.py index c6014b84e..4af3ddb36 100644 --- a/python/test/test_mscclpp.py +++ b/python/test/test_mscclpp.py @@ -24,6 +24,7 @@ TcpBootstrap, Transport, is_nvls_supported, + npkit, ) import mscclpp.comm as mscclpp_comm from mscclpp.utils import KernelBuilder, pack @@ -603,6 +604,9 @@ def test_executor(mpi_group: MpiGroup, filename: str): project_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) mscclpp_group = mscclpp_comm.CommGroup(mpi_group.comm) executor = Executor(mscclpp_group.communicator) + npkit_dump_dir = os.getenv("NPKIT_DUMP_DIR") + if npkit_dump_dir is not None: + npkit.init(mscclpp_group.my_rank) execution_plan = ExecutionPlan("allreduce_pairs", os.path.join(project_dir, "test", "execution-files", filename)) nelems = 1024 * 1024 @@ -629,3 +633,6 @@ def test_executor(mpi_group: MpiGroup, filename: str): ) stream.synchronize() assert cp.allclose(sendbuf, expected, atol=1e-3 * mpi_group.comm.size) + if npkit_dump_dir is not None: + npkit.dump(npkit_dump_dir) + npkit.shutdown() diff --git a/src/connection.cc b/src/connection.cc index b5fd5b9b9..fc3724c08 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -3,6 +3,9 @@ #include "connection.hpp" +#if defined(ENABLE_NPKIT) +#include +#endif #include #include #include @@ -10,7 +13,6 @@ #include "debug.h" #include "endpoint.hpp" #include "infiniband/verbs.h" -#include "npkit/npkit.h" namespace mscclpp { diff --git a/src/executor/execution_kernel.cu b/src/executor/execution_kernel.cu index 4e96af9ab..06079f439 100644 --- a/src/executor/execution_kernel.cu +++ b/src/executor/execution_kernel.cu @@ -13,19 +13,43 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo switch (dataType) { case DataType::INT32: executionKernel<<>>( - rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag); + rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag +#if defined(ENABLE_NPKIT) + , + NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); +#else + ); +#endif break; case DataType::UINT32: executionKernel<<>>( - rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, 
scratchSize, plan, flag); + rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag +#if defined(ENABLE_NPKIT) + , + NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); +#else + ); +#endif break; case DataType::FLOAT16: executionKernel<<>>( - rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag); + rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag +#if defined(ENABLE_NPKIT) + , + NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); +#else + ); +#endif break; case DataType::FLOAT32: executionKernel<<>>( - rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag); + rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag +#if defined(ENABLE_NPKIT) + , + NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); +#else + ); +#endif break; } } diff --git a/src/executor/executor.cc b/src/executor/executor.cc index f8d9f8e83..0d5e75f04 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -261,7 +261,11 @@ struct Executor::Impl { DataType dataType, cudaStream_t stream, PacketType packetType) { static uint32_t flag = 0; int nthreadblocks = context.deviceExecutionPlans.size(); +#if defined(ENABLE_NPKIT) + size_t sharedMemSize = sizeof(DeviceExecutionPlan) + NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent); +#else size_t sharedMemSize = sizeof(DeviceExecutionPlan); +#endif switch (packetType) { case PacketType::LL16: ExecutionKernel::launchKernel( diff --git a/src/include/execution_kernel.hpp b/src/include/execution_kernel.hpp index a69013c5b..834e0f3f8 100644 --- a/src/include/execution_kernel.hpp +++ b/src/include/execution_kernel.hpp @@ -5,6 +5,9 @@ #define MSCCLPP_EXECUTION_KERNEL_HPP_ #include +#if defined(ENABLE_NPKIT) +#include +#endif #include #include #include @@ -333,10 +336,26 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T template __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* input, T* output, T* scratch, - size_t scratchSize, DeviceExecutionPlan* plan, uint32_t flag) { + size_t scratchSize, DeviceExecutionPlan* plan, uint32_t flag +#if defined(ENABLE_NPKIT) + , + NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) { +#else +) { +#endif extern __shared__ int4 sharedMem[]; int bid = blockIdx.x; int tid = threadIdx.x; +#if defined(ENABLE_NPKIT) + NpKitEvent* event_buffer = (NpKitEvent*)((char*)sharedMem + sizeof(DeviceExecutionPlan)); + uint64_t event_buffer_head = 0; +#if defined(ENABLE_NPKIT_EVENT_EXECUTOR_INIT_ENTRY) && defined(ENABLE_NPKIT_EVENT_EXECUTOR_INIT_EXIT) + uint64_t npkit_timestamp_entry = 0; + if (tid == 0) { + npkit_timestamp_entry = NPKIT_GET_GPU_TIMESTAMP(); + } +#endif +#endif DeviceExecutionPlan* localPlan = plan + bid; for (size_t i = tid; i < sizeof(DeviceExecutionPlan) / sizeof(int4); i += blockDim.x) { sharedMem[i] = ((int4*)localPlan)[i]; @@ -352,8 +371,31 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu DeviceHandle* smChannels = localPlan->channels.smChannels; DeviceHandle* proxyChannels = localPlan->channels.proxyChannels; +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_CPU) + NpKit::CollectGpuEventShm(NPKIT_EVENT_TIME_SYNC_CPU, 0, 0, *cpuTimestamp, event_buffer, &event_buffer_head); +#endif + +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_TIME_SYNC_GPU) + NpKit::CollectGpuEventShm(NPKIT_EVENT_TIME_SYNC_GPU, 0, 0, NPKIT_GET_GPU_TIMESTAMP(), event_buffer, + 
&event_buffer_head); +#endif + +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_EXECUTOR_INIT_ENTRY) && \ + defined(ENABLE_NPKIT_EVENT_EXECUTOR_INIT_EXIT) + NpKit::CollectGpuEventShm(NPKIT_EVENT_EXECUTOR_INIT_ENTRY, 0, 0, npkit_timestamp_entry, event_buffer, + &event_buffer_head); + NpKit::CollectGpuEventShm(NPKIT_EVENT_EXECUTOR_INIT_EXIT, 0, 0, NPKIT_GET_GPU_TIMESTAMP(), event_buffer, + &event_buffer_head); +#endif + for (int i = 0; i < nOperations; i++) { Operation& op = operations[i]; + +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY) + NpKit::CollectGpuEventShm(NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY + (int)op.type, op.size, 0, NPKIT_GET_GPU_TIMESTAMP(), + event_buffer, &event_buffer_head); +#endif + if (op.type == OperationType::BARRIER) { __syncthreads(); } else if (op.type == OperationType::SIGNAL) { @@ -403,7 +445,16 @@ __global__ void executionKernel([[maybe_unused]] int rank /*for debug*/, T* inpu handleReduceSend(dst, op.dstOffset, src, op.srcOffset, tmp, op.inputOffsets, smChannels, op.outputChannelIndexes, op.outputOffsets, op.nOutputs, op.size); } + +#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT) + NpKit::CollectGpuEventShm(NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT + (int)op.type, op.size, 0, NPKIT_GET_GPU_TIMESTAMP(), + event_buffer, &event_buffer_head); +#endif } + +#if defined(ENABLE_NPKIT) + NpKit::StoreGpuEventShm(npKitEventCollectContexts, event_buffer, event_buffer_head); +#endif } #endif // defined(MSCCLPP_DEVICE_COMPILE) @@ -417,19 +468,43 @@ class ExecutionKernel { switch (dataType) { case DataType::INT32: executionKernel<<>>( - rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag); + rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag +#if defined(ENABLE_NPKIT) + , + NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); +#else + ); +#endif break; case DataType::UINT32: executionKernel<<>>( - rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag); + rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag +#if defined(ENABLE_NPKIT) + , + NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); +#else + ); +#endif break; case DataType::FLOAT16: executionKernel<<>>( - rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag); + rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag +#if defined(ENABLE_NPKIT) + , + NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); +#else + ); +#endif break; case DataType::FLOAT32: executionKernel<<>>( - rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag); + rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag +#if defined(ENABLE_NPKIT) + , + NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); +#else + ); +#endif break; } } diff --git a/src/npkit/npkit.cc b/src/npkit/npkit.cc index 466806d1f..54bac9d62 100644 --- a/src/npkit/npkit.cc +++ b/src/npkit/npkit.cc @@ -1,13 +1,14 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. 
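A quick note on the shared-memory cost of the instrumentation above: executor.cc grows the dynamic shared memory by NPKIT_SHM_NUM_EVENTS NpKitEvent slots per block, the kernel records events into that staging area, and StoreGpuEventShm flushes it to the per-block global context once at the end of the launch. The budget is small and fixed, which bounds how many operations can be traced per launch; a back-of-the-envelope check (constants taken from npkit.hpp, the per-launch event count assumes the NPKIT_FLAGS set used in CI):

NPKIT_SHM_NUM_EVENTS = 64   # staging slots per thread block (npkit.hpp)
NPKIT_EVENT_BYTES = 16      # packed NpKitEvent size

extra_shared_mem = NPKIT_SHM_NUM_EVENTS * NPKIT_EVENT_BYTES   # 1024 extra bytes per block

# Per launch a block records 2 time-sync events and an INIT entry/exit pair,
# then one entry/exit pair per executed operation.
fixed_events = 2 + 2
max_traced_ops = (NPKIT_SHM_NUM_EVENTS - fixed_events) // 2   # 30 operations

print(f"{extra_shared_mem} B of shared memory, {max_traced_ops} traceable ops per block")

CollectGpuEventShm drops anything past slot 64 rather than overflowing, so a block that executes more operations than that simply loses the tail of its trace for that launch.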
-#include "npkit.h" - #include #include #include #include +#include + +#include "debug.h" uint64_t NpKit::rank_ = 0; @@ -16,41 +17,85 @@ std::vector> NpKit::cpu_event_buffers_; mscclpp::UniqueCudaPtr NpKit::gpu_collect_contexts_; std::unique_ptr NpKit::cpu_collect_contexts_; -uint64_t NpKit::cpu_base_system_timestamp_ = 0; -uint64_t NpKit::cpu_base_steady_timestamp_ = 0; + +mscclpp::UniqueCudaHostPtr NpKit::cpu_timestamp_; +std::unique_ptr NpKit::cpu_timestamp_update_thread_; +volatile bool NpKit::cpu_timestamp_update_thread_should_stop_ = false; + +void NpKit::CpuTimestampUpdateThread() { + uint64_t init_system_clock = std::chrono::system_clock::now().time_since_epoch().count(); + uint64_t init_steady_clock = std::chrono::steady_clock::now().time_since_epoch().count(); + uint64_t curr_steady_clock = 0; + volatile uint64_t* volatile_cpu_timestamp_ = cpu_timestamp_.get(); + while (!cpu_timestamp_update_thread_should_stop_) { + curr_steady_clock = std::chrono::steady_clock::now().time_since_epoch().count(); + *volatile_cpu_timestamp_ = init_system_clock + (curr_steady_clock - init_steady_clock); + } +} void NpKit::Init(int rank) { +#if defined(ENABLE_NPKIT) uint64_t i = 0; NpKitEventCollectContext ctx; ctx.event_buffer_head = 0; rank_ = rank; // Init event data structures - gpu_collect_contexts_ = mscclpp::allocUniqueCuda(kNumGpuEventBuffers); - for (i = 0; i < kNumGpuEventBuffers; i++) { + gpu_collect_contexts_ = mscclpp::allocUniqueCuda(NpKit::kNumGpuEventBuffers); + for (i = 0; i < NpKit::kNumGpuEventBuffers; i++) { gpu_event_buffers_.emplace_back(mscclpp::allocUniqueCuda(kMaxNumGpuEventsPerBuffer)); ctx.event_buffer = gpu_event_buffers_[i].get(); mscclpp::memcpyCuda(gpu_collect_contexts_.get() + i, &ctx, 1); } - cpu_collect_contexts_ = std::make_unique(kNumCpuEventBuffers); - for (i = 0; i < kNumCpuEventBuffers; i++) { + cpu_collect_contexts_ = std::make_unique(NpKit::kNumCpuEventBuffers); + for (i = 0; i < NpKit::kNumCpuEventBuffers; i++) { cpu_event_buffers_.emplace_back(std::make_unique(kMaxNumCpuEventsPerBuffer)); ctx.event_buffer = cpu_event_buffers_[i].get(); cpu_collect_contexts_[i] = ctx; } // Init timestamp - cpu_base_system_timestamp_ = std::chrono::system_clock::now().time_since_epoch().count(); - cpu_base_steady_timestamp_ = std::chrono::steady_clock::now().time_since_epoch().count(); + cpu_timestamp_ = mscclpp::makeUniqueCudaHost(); + volatile uint64_t* volatile_cpu_timestamp = cpu_timestamp_.get(); + *volatile_cpu_timestamp = std::chrono::system_clock::now().time_since_epoch().count(); + cpu_timestamp_update_thread_should_stop_ = false; + cpu_timestamp_update_thread_ = std::make_unique(CpuTimestampUpdateThread); +#else + WARN("NpKit::Init(%d) : MSCCLPP library was not built with NPKit enabled.", rank); +#endif +} + +#if defined(ENABLE_NPKIT) +static int GetGpuClockRateInKhz() { + int dev_id; +#if defined(__HIP_PLATFORM_AMD__) + cudaDeviceProp_t dev_prop; + char gcn_arch[256]; + MSCCLPP_CUDATHROW(cudaGetDevice(&dev_id)); + MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&dev_prop, dev_id)); + char* gcnArchNameToken = strtok(dev_prop.gcnArchName, ":"); + strcpy(gcn_arch, gcnArchNameToken); + if (strncmp("gfx94", gcn_arch, 5) == 0) + return 100000; + else + return 25000; +#else + cudaDeviceProp dev_prop; + MSCCLPP_CUDATHROW(cudaGetDevice(&dev_id)); + MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&dev_prop, dev_id)); + return dev_prop.clockRate; +#endif } +#endif void NpKit::Dump(const std::string& dump_dir) { +#if defined(ENABLE_NPKIT) uint64_t i = 0; std::string dump_file_path; // Dump 
CPU events - for (i = 0; i < kNumCpuEventBuffers; i++) { + for (i = 0; i < NpKit::kNumCpuEventBuffers; i++) { dump_file_path = dump_dir; dump_file_path += "/cpu_events_rank_"; dump_file_path += std::to_string(rank_); @@ -80,7 +125,7 @@ void NpKit::Dump(const std::string& dump_dir) { clock_period_den_file.close(); // Dump GPU events, reuse CPU struct - for (i = 0; i < kNumGpuEventBuffers; i++) { + for (i = 0; i < NpKit::kNumGpuEventBuffers; i++) { dump_file_path = dump_dir; dump_file_path += "/gpu_events_rank_"; dump_file_path += std::to_string(rank_); @@ -98,17 +143,21 @@ void NpKit::Dump(const std::string& dump_dir) { dump_file_path = dump_dir; dump_file_path += "/gpu_clock_rate_rank_"; dump_file_path += std::to_string(rank_); - cudaDeviceProp dev_prop; - int dev; - MSCCLPP_CUDATHROW(cudaGetDevice(&dev)); - MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&dev_prop, dev)); - std::string clock_rate_str = std::to_string(dev_prop.clockRate); + std::string clock_rate_str = std::to_string(GetGpuClockRateInKhz()); auto gpu_clock_rate_file = std::fstream(dump_file_path, std::ios::out); gpu_clock_rate_file.write(clock_rate_str.c_str(), clock_rate_str.length()); gpu_clock_rate_file.close(); +#else + WARN("NpKit::Dump(%s) : MSCCLPP library was not built with NPKit enabled.", dump_dir.c_str()); +#endif } void NpKit::Shutdown() { +#if defined(ENABLE_NPKIT) + // Stop CPU timestamp updating thread + cpu_timestamp_update_thread_should_stop_ = true; + cpu_timestamp_update_thread_->join(); + // Free CPU event data structures cpu_event_buffers_.clear(); cpu_collect_contexts_.reset(); @@ -116,6 +165,11 @@ void NpKit::Shutdown() { // Free GPU event data structures gpu_event_buffers_.clear(); gpu_collect_contexts_.reset(); + + // Free timestamp + cpu_timestamp_update_thread_.reset(); + cpu_timestamp_.reset(); +#endif } NpKitEventCollectContext* NpKit::GetGpuEventCollectContexts() { return gpu_collect_contexts_.get(); } @@ -132,7 +186,4 @@ void NpKit::CollectCpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t } } -uint64_t NpKit::GetCpuTimestamp() { - uint64_t cpu_curr_steady_timestamp_ = std::chrono::steady_clock::now().time_since_epoch().count(); - return cpu_base_steady_timestamp_ + (cpu_curr_steady_timestamp_ - cpu_base_steady_timestamp_); -} +uint64_t* NpKit::GetCpuTimestamp() { return cpu_timestamp_.get(); } diff --git a/src/npkit/npkit.h b/src/npkit/npkit.h deleted file mode 100644 index 21ba928ae..000000000 --- a/src/npkit/npkit.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. 
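GPU events carry raw clock64()/wall_clock64() ticks, which is why Dump writes the SM clock rate in kHz (with the gfx94x special case on AMD) alongside the event buffers. Converting ticks to microseconds is then a one-liner; the sketch below reads the gpu_clock_rate_rank_<rank> file exactly as Dump names it, while the final alignment against the CPU timeline (via the TIME_SYNC_CPU/TIME_SYNC_GPU pair recorded at kernel start) is left to npkit_trace_generator.py.

def gpu_ticks_to_us(dump_dir, rank, ticks):
    """Convert a raw GPU timestamp to microseconds using the clock rate written by NpKit::Dump."""
    with open(f"{dump_dir}/gpu_clock_rate_rank_{rank}") as f:
        freq_khz = int(f.read().strip())   # e.g. 1410000 kHz for a 1.41 GHz SM clock
    return ticks / (freq_khz / 1000.0)     # kHz / 1000 = ticks per microsecond

# Example: 2_820_000 ticks at 1.41 GHz -> 2000.0 microseconds.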
- -#ifndef NPKIT_H_ -#define NPKIT_H_ - -#include -#include -#include - -#include "npkit_event.h" -#include "npkit_struct.h" - -class NpKit { - public: - static const uint64_t kNumGpuEventBuffers = 512; - - static const uint64_t kNumCpuEventBuffers = 32; - - static void Init(int rank); - - static void Dump(const std::string& dump_dir); - - static void Shutdown(); - - static NpKitEventCollectContext* GetGpuEventCollectContexts(); - -#ifdef __CUDACC__ - static inline __device__ void CollectGpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp, - NpKitEventCollectContext* ctx) { - uint64_t event_buffer_head = ctx->event_buffer_head; - if (event_buffer_head < kMaxNumGpuEventsPerBuffer) { - NpKitEvent& event = ctx->event_buffer[event_buffer_head]; - event.fields.type = type; - event.fields.size = size; - event.fields.rsvd = rsvd; - event.fields.timestamp = timestamp; - ctx->event_buffer_head++; - } - } -#endif // __CUDACC__ - - static void CollectCpuEvent(uint8_t type, uint32_t size, uint32_t rsvd, uint64_t timestamp, int channel_id); - - static uint64_t GetCpuTimestamp(); - - private: - // 64K * 512 * 16B = 512MB per GPU - static const uint64_t kMaxNumGpuEventsPerBuffer = 1ULL << 16; - - // 64K * 2 (send/recv) * (512/32) = 2M, 2M * 32 * 16B = 1GB per CPU - static const uint64_t kMaxNumCpuEventsPerBuffer = 1ULL << 21; - - static std::vector> gpu_event_buffers_; - static std::vector> cpu_event_buffers_; - - static mscclpp::UniqueCudaPtr gpu_collect_contexts_; - static std::unique_ptr cpu_collect_contexts_; - - static uint64_t cpu_base_system_timestamp_; - static uint64_t cpu_base_steady_timestamp_; - - static uint64_t rank_; -}; - -#endif \ No newline at end of file diff --git a/src/npkit/npkit_event.h b/src/npkit/npkit_event.h deleted file mode 100644 index f17e71363..000000000 --- a/src/npkit/npkit_event.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. 
- -#ifndef NPKIT_EVENT_H_ -#define NPKIT_EVENT_H_ - -#define NPKIT_EVENT_INVALID 0x0 - -#define NPKIT_EVENT_TIME_SYNC_GPU 0x1 -#define NPKIT_EVENT_TIME_SYNC_CPU 0x2 - -#define NPKIT_EVENT_SM_REDUCE_ENTRY 0x3 -#define NPKIT_EVENT_SM_REDUCE_EXIT 0x4 - -#define NPKIT_EVENT_IB_SEND_DATA_ENTRY 0x5 -#define NPKIT_EVENT_IB_SEND_FLAG_ENTRY 0x6 -#define NPKIT_EVENT_IB_SEND_EXIT 0x7 - -#define NPKIT_EVENT_DMA_SEND_DATA_ENTRY 0x8 -#define NPKIT_EVENT_DMA_SEND_FLAG_ENTRY 0x9 -#define NPKIT_EVENT_DMA_SEND_EXIT 0xA - -#endif \ No newline at end of file diff --git a/test/executor_test.cc b/test/executor_test.cc index a30691dde..24796dd4b 100644 --- a/test/executor_test.cc +++ b/test/executor_test.cc @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -74,11 +75,13 @@ double benchTime(int rank, std::shared_ptr bootstrap, std::s } int main(int argc, char* argv[]) { - if (argc != 5) { + if (argc != 7) { std::cerr << "Usage: " << argv[0] << " " << " " << " " - << " " << std::endl; + << " " + << " " + << " " << std::endl; return 1; } @@ -93,6 +96,9 @@ int main(int argc, char* argv[]) { const std::string executionPlanName = argv[2]; const std::string executionPlanPath = argv[3]; const int nthreadsPerBlock = std::stoi(argv[4]); + const int niters = std::stoi(argv[5]); + const int ngraphIters = std::stoi(argv[6]); + const char* npkitDumpDir = getenv("NPKIT_DUMP_DIR"); std::shared_ptr bootstrap; mscclpp::UniqueId id; @@ -103,11 +109,22 @@ int main(int argc, char* argv[]) { std::shared_ptr communicator = std::make_shared(bootstrap); std::shared_ptr executor = std::make_shared(communicator); + if (npkitDumpDir != nullptr) { + NpKit::Init(rank); + } + mscclpp::ExecutionPlan plan(executionPlanName, executionPlanPath); std::shared_ptr sendbuff = mscclpp::allocExtSharedCuda(bufferSize); std::vector dataHost(bufferSize / sizeof(int), rank); MSCCLPP_CUDATHROW(cudaMemcpy(sendbuff.get(), dataHost.data(), bufferSize, cudaMemcpyHostToDevice)); - double deltaSec = benchTime(rank, bootstrap, executor, plan, sendbuff, bufferSize, nthreadsPerBlock, 200, 20); + double deltaSec = + benchTime(rank, bootstrap, executor, plan, sendbuff, bufferSize, nthreadsPerBlock, niters, ngraphIters); + + if (npkitDumpDir != nullptr) { + NpKit::Dump(npkitDumpDir); + NpKit::Shutdown(); + } + std::cout << "Rank " << rank << ": " << bufferSize << " bytes " << deltaSec * 1.e6 << " us" << std::endl; MPI_Finalize(); return 0; diff --git a/test/mp_unit/executor_tests.cc b/test/mp_unit/executor_tests.cc index fb1d104be..5baa2b67a 100644 --- a/test/mp_unit/executor_tests.cc +++ b/test/mp_unit/executor_tests.cc @@ -4,6 +4,7 @@ #include #include +#include #include "mp_unit_tests.hpp" @@ -30,9 +31,17 @@ void ExecutorTest::SetUp() { bootstrap->initialize(id); std::shared_ptr communicator = std::make_shared(bootstrap); executor = std::make_shared(communicator); + npkitDumpDir = getenv("NPKIT_DUMP_DIR"); + if (npkitDumpDir != nullptr) { + NpKit::Init(gEnv->rank); + } } void ExecutorTest::TearDown() { + if (npkitDumpDir != nullptr) { + NpKit::Dump(npkitDumpDir); + NpKit::Shutdown(); + } executor.reset(); MultiProcessTest::TearDown(); } diff --git a/test/mp_unit/mp_unit_tests.hpp b/test/mp_unit/mp_unit_tests.hpp index e13a05104..8afa8e917 100644 --- a/test/mp_unit/mp_unit_tests.hpp +++ b/test/mp_unit/mp_unit_tests.hpp @@ -170,5 +170,6 @@ class ExecutorTest : public MultiProcessTest { void TearDown() override; std::shared_ptr executor; + const char* npkitDumpDir; }; #endif // MSCCLPP_MP_UNIT_TESTS_HPP_ diff --git a/tools/npkit/npkit_trace_generator.py 
b/tools/npkit/npkit_trace_generator.py index 4f2bc1b5f..8c15a3ac0 100644 --- a/tools/npkit/npkit_trace_generator.py +++ b/tools/npkit/npkit_trace_generator.py @@ -2,12 +2,34 @@ # Licensed under the MIT License. import argparse -import json import os +import json + +from queue import Queue def parse_npkit_event_header(npkit_event_header_path): npkit_event_def = {"id_to_type": {}, "type_to_id": {}} + executor_ops = [ + "BARRIER", + "PUT", + "PUT_PACKET", + "GET", + "COPY", + "COPY_PACKET", + "SIGNAL", + "WAIT", + "FLUSH", + "REDUCE", + "REDUCE_PACKET", + "REDUCE_SEND", + "REDUCE_SEND_PACKET", + "READ_REDUCE_COPY", + "READ_REDUCE_COPY_SEND", + ] + executor_op_to_offset = {} + for executor_op in executor_ops: + executor_op_to_offset[executor_op] = len(executor_op_to_offset) with open(npkit_event_header_path, "r") as f: lines = [x.strip() for x in f.readlines() if len(x.strip()) != 0] line_idx = 0 @@ -17,23 +39,22 @@ def parse_npkit_event_header(npkit_event_header_path): if len(fields) == 3: event_type = fields[1] event_id = int(fields[2], 0) - npkit_event_def["type_to_id"][event_type] = event_id - npkit_event_def["id_to_type"][event_id] = event_type + if lines[line_idx].startswith("#define NPKIT_EVENT_EXECUTOR_OP_BASE"): + for executor_op in executor_op_to_offset: + real_event_id = event_id + executor_op_to_offset[executor_op] + if "ENTRY" in lines[line_idx]: + event_type = "NPKIT_EVENT_EXECUTOR_%s_ENTRY" % executor_op + elif "EXIT" in lines[line_idx]: + event_type = "NPKIT_EVENT_EXECUTOR_%s_EXIT" % executor_op + npkit_event_def["type_to_id"][event_type] = real_event_id + npkit_event_def["id_to_type"][real_event_id] = event_type + else: + npkit_event_def["type_to_id"][event_type] = event_id + npkit_event_def["id_to_type"][event_id] = event_type line_idx += 1 return npkit_event_def -def trim_event_name(event_type): - list_event_type_name = event_type.split("_") - if "NPKIT" in list_event_type_name: - list_event_type_name.remove("NPKIT") - if "EVENT" in list_event_type_name: - list_event_type_name.remove("EVENT") - if "ENTRY" in list_event_type_name: - list_event_type_name.remove("ENTRY") - return "_".join(list_event_type_name) - - def parse_gpu_clock_scale(gpu_clock_file_path): with open(gpu_clock_file_path, "r") as f: freq_in_khz = f.read() @@ -103,7 +124,7 @@ def parse_gpu_event_file(npkit_dump_dir, npkit_event_def, rank, buf_idx, gpu_clo event_type_to_seq[event_type] = 0 gpu_events[-1].update( { - "name": trim_event_name(event_type), + "name": event_type, "cat": "GPU", "args": { "rank": rank, @@ -116,12 +137,11 @@ def parse_gpu_event_file(npkit_dump_dir, npkit_event_def, rank, buf_idx, gpu_clo ) event_type_to_seq[event_type] += 1 else: - gpu_events[-1]["args"] = { - "size": parsed_gpu_event["size"], - "rsvd": parsed_gpu_event["rsvd"], - } + gpu_events[-1]["args"] = {"size": parsed_gpu_event["size"], "rsvd": parsed_gpu_event["rsvd"]} delta_time = gpu_events[-1]["ts"] - gpu_events[-2]["ts"] - gpu_events[-1]["args"]["bw (GB/s)"] = gpu_events[-1]["args"]["size"] / delta_time / 1e3 + gpu_events[-1]["args"]["bw (GB/s)"] = ( + 0.0 if delta_time == 0.0 else gpu_events[-1]["args"]["size"] / delta_time / 1e3 + ) raw_content_idx += raw_event_size return gpu_events @@ -133,7 +153,7 @@ def parse_cpu_event_file(npkit_dump_dir, npkit_event_def, rank, channel, cpu_clo event_type_to_seq = {} fiber_is_usable = [] - fiber_open_info = [] + fiber_open_ts = [] slot_to_fiber_id = {} channel_shift = 1000 @@ -156,17 +176,16 @@ def parse_cpu_event_file(npkit_dump_dir, npkit_event_def, rank, channel, cpu_clo fiber_id 
+= 1 if fiber_id == len(fiber_is_usable): fiber_is_usable.append(True) - fiber_open_info.append({"ts": 0.0, "size": 0}) + fiber_open_ts.append(0.0) slot_to_fiber_id[slot] = fiber_id - fiber_open_info[fiber_id]["ts"] = cpu_events[-1]["ts"] - fiber_open_info[fiber_id]["size"] = parsed_cpu_event["size"] + fiber_open_ts[fiber_id] = cpu_events[-1]["ts"] fiber_is_usable[fiber_id] = False if event_type not in event_type_to_seq: event_type_to_seq[event_type] = 0 cpu_events[-1].update( { - "name": trim_event_name(event_type), + "name": event_type, "cat": "CPU", "args": { "rank": rank, @@ -182,16 +201,14 @@ def parse_cpu_event_file(npkit_dump_dir, npkit_event_def, rank, channel, cpu_clo # Close fiber event fiber_id = slot_to_fiber_id[slot] slot_to_fiber_id.pop(slot) - last_ts = fiber_open_info[fiber_id]["ts"] - last_size = fiber_open_info[fiber_id]["size"] + last_ts = fiber_open_ts[fiber_id] fiber_is_usable[fiber_id] = True delta_time = max(0.001, cpu_events[-1]["ts"] - last_ts) - cpu_events[-1]["args"] = { - "size_1": parsed_cpu_event["size"], - "size": max(last_size, parsed_cpu_event["size"]), - } - cpu_events[-1]["args"]["bw (GB/s)"] = cpu_events[-1]["args"]["size"] / delta_time / 1e3 + cpu_events[-1]["args"] = {"size": parsed_cpu_event["size"]} + cpu_events[-1]["args"]["bw (GB/s)"] = ( + 0.0 if delta_time == 0.0 else cpu_events[-1]["args"]["size"] / delta_time / 1e3 + ) cpu_events[-1]["tid"] = fiber_id + (channel + 1) * channel_shift @@ -239,12 +256,7 @@ def convert_npkit_dump_to_trace(npkit_dump_dir, output_dir, npkit_event_def): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--npkit_dump_dir", type=str, required=True, help="NPKit dump directory.") - parser.add_argument( - "--npkit_event_header_path", - type=str, - required=True, - help="Path to npkit_event.h.", - ) + parser.add_argument("--npkit_event_header_path", type=str, required=True, help="Path to npkit_event.h.") parser.add_argument("--output_dir", type=str, required=True, help="Path to output directory.") args = parser.parse_args()
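Once npkit_event_trace.json exists, basic per-operation timing can be pulled out without opening a trace viewer. The sketch below assumes the generator emits standard Chrome-trace begin/end events (ph "B"/"E", closed per pid/tid) and, now that trim_event_name() is gone, full NPKIT_EVENT_* names; both are assumptions about the output shape rather than documented guarantees.

import json
from collections import defaultdict

def summarize_trace(path="./npkit_output/npkit_event_trace.json"):
    """Total time per NPKit event name, paired from Chrome-trace B/E events."""
    with open(path) as f:
        trace = json.load(f)
    events = trace["traceEvents"] if isinstance(trace, dict) else trace
    stacks = defaultdict(list)     # (pid, tid) -> stack of open (name, begin_ts)
    totals = defaultdict(float)    # event name -> accumulated microseconds
    for ev in events:
        key = (ev.get("pid"), ev.get("tid"))
        ph = ev.get("ph")
        if ph == "B":
            stacks[key].append((ev.get("name", "UNKNOWN"), ev["ts"]))
        elif ph == "E" and stacks[key]:
            name, begin_ts = stacks[key].pop()
            totals[name] += ev["ts"] - begin_ts
    for name, us in sorted(totals.items(), key=lambda kv: -kv[1]):
        print(f"{name:55s} {us:12.1f} us")

if __name__ == "__main__":
    summarize_trace()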