From 0bc3c3e574b896f591aba2f049d9db02f138753f Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Sat, 22 Apr 2023 00:35:25 +0000 Subject: [PATCH 01/54] Core API teasing out WIP --- src/epoch.cc | 22 ++ src/include/channel.hpp | 295 ++++++++++++++++++++ src/include/epoch.hpp | 52 ++++ src/include/mscclpp.hpp | 477 +++++---------------------------- src/include/proxy.hpp | 39 +++ src/include/registered_ptr.hpp | 40 +++ 6 files changed, 512 insertions(+), 413 deletions(-) create mode 100644 src/epoch.cc create mode 100644 src/include/channel.hpp create mode 100644 src/include/epoch.hpp create mode 100644 src/include/proxy.hpp create mode 100644 src/include/registered_ptr.hpp diff --git a/src/epoch.cc b/src/epoch.cc new file mode 100644 index 000000000..1fee307ea --- /dev/null +++ b/src/epoch.cc @@ -0,0 +1,22 @@ +#include "epoch.hpp" +#include "checks.hpp" + +namespace mscclpp { + +struct Epoch::Impl { + DeviceEpoch deviceEpoch; + + Impl() { + MSCCLPPTHROW(mscclppCudaCalloc(&deviceEpoch.localSignalEpochId, 1)); + MSCCLPPTHROW(mscclppCudaCalloc(&deviceEpoch.waitEpochId, 1)); + } + + ~Impl() { + MSCCLPPTHROW(mscclppCudaFree(deviceEpoch.localSignalEpochId)); + MSCCLPPTHROW(mscclppCudaFree(deviceEpoch.waitEpochId)); + } +}; + +Epoch::Epoch() : pimpl(std::make_unique()) {} + +} // namespace mscclpp \ No newline at end of file diff --git a/src/include/channel.hpp b/src/include/channel.hpp new file mode 100644 index 000000000..cb1931b07 --- /dev/null +++ b/src/include/channel.hpp @@ -0,0 +1,295 @@ +#ifndef MSCCLPP_CHANNEL_HPP_ +#define MSCCLPP_CHANNEL_HPP_ + +#include "mscclpp.hpp" +#include "proxy.hpp" + +namespace mscclpp { + +// For every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER, a flush of the tail to device memory is triggered. +// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem. 
+#define MSCCLPP_PROXY_FIFO_SIZE 128 +#define MSCCLPP_PROXY_FIFO_FLUSH_COUNTER 4 + +using ChannelTriggerType = uint64_t; +const ChannelTriggerType channelTriggerData = 0x1; +const ChannelTriggerType channelTriggerFlag = 0x2; +const ChannelTriggerType channelTriggerSync = 0x4; + +// This is just a numeric ID. Each HostConnection will have an internal array indexed by these handles +// mapping to the actual +using BufferHandle = uint32_t; + +#define MSCCLPP_BITS_SIZE 32 +#define MSCCLPP_BITS_OFFSET 32 +#define MSCCLPP_BITS_BUFFER_HANDLE 8 +#define MSCCLPP_BITS_TYPE 3 +#define MSCCLPP_BITS_CONNID 10 + +// this is the basic structure of each work element in the fifo +// the summation of number of bits must be 128 or less +union ChannelTrigger { + ProxyTrigger value; + struct + { + // first 64 bits: value[0] + uint64_t size : MSCCLPP_BITS_SIZE; + uint64_t srcOffset : MSCCLPP_BITS_OFFSET; + uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment + // second 64 bits: value[1] + uint64_t dstOffset : MSCCLPP_BITS_OFFSET; + uint64_t srcBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE; + uint64_t dstBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE; + uint64_t type : MSCCLPP_BITS_TYPE; + uint64_t connId : MSCCLPP_BITS_CONNID; + uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_TYPE); // ensure 64-bit alignment + } fields; + +#ifdef __CUDACC__ + __device__ ChannelTrigger() {} + __device__ ChannelTrigger(ProxyTrigger value) : value(value) {} + __device__ ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size, int connectionId) { + value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size); + value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_BUFFER_HANDLE) + dst) << MSCCLPP_BITS_BUFFER_HANDLE) + src) << MSCCLPP_BITS_OFFSET) + dstOffset); + } +#endif // __CUDACC__ +}; + +struct ConnectionEpoch { 
+#ifdef __CUDACC__ + __forceinline__ __device__ void wait() + { + (*waitEpochId) += 1; + while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId)) + ; + } + + __forceinline__ __device__ void epochIncrement() + { + *(volatile uint64_t*)&(localSignalEpochId->device) += 1; + } +#endif // __CUDACC__ + + SignalEpochId* localSignalEpochId; + // used by the signal() function directly from gpu + SignalEpochId* remoteSignalEpochId; + + // every wait(), increments this and then the gpu waits for either: + // 1) localSignalEpochId->proxy to be >= this in case of a proxy thread + // 2) remoteSignalEpochId->device to be >= this in case of a gpu thread + uint64_t* waitEpochId; +}; + +class HostConnection { + struct Impl; +public: + /* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */ + HostConnection(std::unique_ptr); + + ~HostConnection(); + + void write() + + int getId(); + + /* Get the number of times registerBuffer(...) was called. + * + * Returns: the number of buffers registered + */ + int numLocalBuffers(); + + /* Get the BufferHandle returned by a call to registerBuffer(...) as identified by the index + * + * Inputs: + * index: the index of the handle to get + * + * Returns: a handle to the buffer + */ + BufferHandle getLocalBuffer(int index); + + /* Get the number of times registerBuffer(...) was called on the remote peer. + * + * Returns: the number of buffers registered on the remote peer + */ + int numRemoteBuffers(); + + /* Get the BufferHandle returned by a call to registerBuffer(...) 
on the remote peer as identified by the index + * + * Inputs: + * index: the index of the handle to get + * + * Returns: a handle to the buffer on the remote peer + */ + BufferHandle getRemoteBuffer(int index); + + ConnectionEpoch getEpoch(); + + DeviceProxyFifo getDeviceFifo(); + + void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size); + + void signal(); + + void flush(); + + void wait(); + +private: + std::unique_ptr pimpl; + friend class Communicator; +}; + +struct DeviceConnection { + DeviceConnection() = default; + + DeviceConnection(HostConnection& hostConn) + : connectionId(hostConn.getId()), epoch(hostConn.getEpoch()), + fifo(hostConn.getDeviceFifo()) {} + + DeviceConnection(const DeviceConnection& other) = default; + + DeviceConnection& operator=(DeviceConnection& other) = default; + +#ifdef __CUDACC__ + __forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) + { + fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value); + } + + __forceinline__ __device__ void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) + { + put(dst, offset, src, offset, size); + } + + __forceinline__ __device__ void signal() + { + epochIncrement(); + fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1, connectionId).value); + } + + __forceinline__ __device__ void putWithSignal(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) + { + epochIncrement(); + fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size, connectionId).value); + } + + __forceinline__ __device__ void putWithSignal(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) + { + putWithSignal(dst, offset, src, offset, size); + } + + __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, uint64_t dstOffset, 
BufferHandle src, uint64_t srcOffset, uint64_t size) + { + epochIncrement(); + uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag | channelTriggerSync, dst, dstOffset, src, srcOffset, size, connectionId).value); + while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && + *(volatile uint64_t*)fifo.tailReplica <= curFifoHead) + ; + } + + __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) + { + putWithSignalAndFlush(dst, offset, src, offset, size); + } + + __forceinline__ __device__ void flush() + { + uint64_t curFifoHead = fifo.push(ChannelTrigger(mscclppSync, 0, 0, 0, 0, 1, connectionId).value); + // we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail + // to go pass by curFifoHead (this is safety net) and (2) wait for the work element value to change to 0. + while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && + *(volatile uint64_t*)fifo.tailReplica <= curFifoHead) + ; + } + + __forceinline__ __device__ void wait() + { + epoch.wait(); + } + + __forceinline__ __device__ void epochIncrement() + { + epoch.epochIncrement(); + } +#endif // __CUDACC__ + + int connectionId; + + ConnectionEpoch epoch; + + // this is a concurrent fifo which is multiple threads from the device + // can produce for and the sole proxy thread consumes it. 
+ DeviceProxyFifo fifo; +}; + +struct SimpleDeviceConnection { + SimpleDeviceConnection() = default; + + SimpleDeviceConnection(HostConnection& hostConn) : devConn(hostConn) { + dst = hostConn.getRemoteBuffer(0); + src = hostConn.getLocalBuffer(0); + } + + SimpleDeviceConnection(const SimpleDeviceConnection& other) = default; + + SimpleDeviceConnection& operator=(SimpleDeviceConnection& other) = default; + +#ifdef __CUDACC__ + + __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) + { + devConn.put(dst, dstOffset, src, srcOffset, size); + } + + __forceinline__ __device__ void put(uint64_t offset, uint64_t size) + { + put(offset, offset, size); + } + + __forceinline__ __device__ void signal() + { + devConn.signal(); + } + + __forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) + { + devConn.putWithSignal(dst, dstOffset, src, srcOffset, size); + } + + __forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) + { + putWithSignal(offset, offset, size); + } + + __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) + { + devConn.putWithSignalAndFlush(dst, dstOffset, src, srcOffset, size); + } + + __forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size) + { + putWithSignalAndFlush(offset, offset, size); + } + + __forceinline__ __device__ void flush() + { + devConn.flush(); + } + + __forceinline__ __device__ void wait() + { + devConn.wait(); + } + + __forceinline__ __device__ void epochIncrement() + { + devConn.epochIncrement(); + } + +#endif // __CUDACC__ + + DeviceConnection devConn; + BufferHandle dst; + BufferHandle src; +}; + diff --git a/src/include/epoch.hpp b/src/include/epoch.hpp new file mode 100644 index 000000000..942edd8b4 --- /dev/null +++ b/src/include/epoch.hpp @@ -0,0 +1,52 @@ +#ifndef MSCCLPP_EPOCH_HPP_ +#define MSCCLPP_EPOCH_HPP_ + +#include "mscclpp.hpp" + +namespace 
mscclpp { + +struct alignas(16) SignalEpochId { + // every signal(), increaments this and either: + // 1) proxy thread pushes it to the remote peer's localSignalEpochId->proxy + // 2) gpu thread directly writes it to remoteSignalEpochId->device + uint64_t device; + // signal() function triggers the cpu proxy thread to write to it + uint64_t proxy; +}; + +struct DeviceEpoch { +#ifdef __CUDACC__ + __forceinline__ __device__ void wait() + { + (*waitEpochId) += 1; + while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId)) + ; + } + + __forceinline__ __device__ void epochIncrement() + { + *(volatile uint64_t*)&(localSignalEpochId->device) += 1; + } +#endif // __CUDACC__ + + SignalEpochId* localSignalEpochId; + SignalEpochId* remoteSignalEpochId; + uint64_t* waitEpochId; +}; + + +class Epoch { + struct Impl; + std::unique_ptr pimpl; +public: + Epoch(); + ~Epoch(); + + void signal(); + + DeviceEpoch& getDeviceEpoch(); +}; + +} // namespace mscclpp + +#endif // MSCCLPP_EPOCH_HPP_ \ No newline at end of file diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index e41e94b8b..67d400508 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -6,402 +6,85 @@ #define MSCCLPP_PATCH 0 #define MSCCLPP_VERSION (MSCCLPP_MAJOR * 10000 + MSCCLPP_MINOR * 100 + MSCCLPP_PATCH) -// For every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER, a flush of the tail to device memory is triggered. -// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem. 
-#define MSCCLPP_PROXY_FIFO_SIZE 128 -#define MSCCLPP_PROXY_FIFO_FLUSH_COUNTER 4 - #include #include -#include - -#include namespace mscclpp { -struct alignas(16) SignalEpochId { - // every signal(), increaments this and either: - // 1) proxy thread pushes it to the remote peer's localSignalEpochId->proxy - // 2) gpu thread directly writes it to remoteSignalEpochId->device - uint64_t device; - // signal() function triggers the cpu proxy thread to write to it - uint64_t proxy; +#define MSCCLPP_UNIQUE_ID_BYTES 128 +struct UniqueId { + char internal[MSCCLPP_UNIQUE_ID_BYTES]; }; -using ChannelTriggerType = uint64_t; -const ChannelTriggerType channelTriggerData = 0x1; -const ChannelTriggerType channelTriggerFlag = 0x2; -const ChannelTriggerType channelTriggerSync = 0x4; - -// This is just a numeric ID. Each HostConnection will have an internal array indexed by these handles -// mapping to the actual -using BufferHandle = uint32_t; - -#define MSCCLPP_BITS_SIZE 32 -#define MSCCLPP_BITS_OFFSET 32 -#define MSCCLPP_BITS_BUFFER_HANDLE 8 -#define MSCCLPP_BITS_TYPE 3 -#define MSCCLPP_BITS_CONNID 10 - -// this is the basic structure of each work element in the fifo -// the summation of number of bits must be 128 or less -union ChannelTrigger { - ProxyTrigger value; - struct - { - // first 64 bits: value[0] - uint64_t size : MSCCLPP_BITS_SIZE; - uint64_t srcOffset : MSCCLPP_BITS_OFFSET; - uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment - // second 64 bits: value[1] - uint64_t dstOffset : MSCCLPP_BITS_OFFSET; - uint64_t srcBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE; - uint64_t dstBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE; - uint64_t type : MSCCLPP_BITS_TYPE; - uint64_t connId : MSCCLPP_BITS_CONNID; - uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_TYPE); // ensure 64-bit alignment - } fields; +/* Create a unique ID for communication. Only needs to be called by one process. 
+ * Use with mscclppCommInitRankFromId(). + * All processes need to provide the same ID to mscclppCommInitRankFromId(). + * + * Outputs: + * uniqueId: the unique ID to be created + */ +std::unique_ptr getUniqueId(); -#ifdef __CUDACC__ - __device__ ChannelTrigger() {} - __device__ ChannelTrigger(ProxyTrigger value) : value(value) {} - __device__ ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size, int connectionId) { - value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size); - value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_BUFFER_HANDLE) + dst) << MSCCLPP_BITS_BUFFER_HANDLE) + src) << MSCCLPP_BITS_OFFSET) + dstOffset); - } -#endif // __CUDACC__ -}; +using TransportFlags = uint32_t; +const TransportFlags TransportCudaIpc = 0b1; +const TransportFlags TransportIB = 0b10; +const TransportFlags TransportIB1 = 0b100; +const TransportFlags TransportIB2 = 0b1000; +const TransportFlags TransportIB3 = 0b10000; +const TransportFlags TransportIB4 = 0b100000; +const TransportFlags TransportIB5 = 0b1000000; +const TransportFlags TransportIB6 = 0b10000000; +const TransportFlags TransportIB7 = 0b100000000; +const TransportFlags TransportAll = 0b111111111; + +class Communicator; + +class RegisteredMemory { + struct Impl; + std::shared_ptr pimpl; +public: -struct ConnectionEpoch { -#ifdef __CUDACC__ - __forceinline__ __device__ void wait() - { - (*waitEpochId) += 1; - while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId)) - ; - } + RegisteredMemory(std::shared_ptr pimpl); + ~RegisteredMemory(); - __forceinline__ __device__ void epochIncrement() - { - *(volatile uint64_t*)&(localSignalEpochId->device) += 1; - } -#endif // __CUDACC__ + void* data(); + size_t size(); + TransportFlags transports(); - SignalEpochId* localSignalEpochId; - // used by the signal() function directly from gpu - SignalEpochId* remoteSignalEpochId; + std::vector serialize(); + 
static RegisteredMemory deserialize(const std::vector& data); - // every wait(), increments this and then the gpu waits for either: - // 1) localSignalEpochId->proxy to be >= this in case of a proxy thread - // 2) remoteSignalEpochId->device to be >= this in case of a gpu thread - uint64_t* waitEpochId; + int rank(); + bool isLocal(); + bool isRemote(); }; -class HostConnection { +class Connection { struct Impl; + std::unique_ptr pimpl; public: - /* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */ - HostConnection(std::unique_ptr); - - ~HostConnection(); - - int getId(); - - /* Register a region of GPU memory for use with this connection. Must be called before connectionSetup() - * in the communicator. - * - * Inputs: - * data: base pointer to the memory - * size: size of the memory region in bytes - * - * Returns: a handle to the buffer - */ - BufferHandle registerBuffer(void* data, uint64_t size); - - /* Get the number of times registerBuffer(...) was called. - * - * Returns: the number of buffers registered - */ - int numLocalBuffers(); - - /* Get the BufferHandle returned by a call to registerBuffer(...) as identified by the index - * - * Inputs: - * index: the index of the handle to get - * - * Returns: a handle to the buffer - */ - BufferHandle getLocalBuffer(int index); - - /* Get the number of times registerBuffer(...) was called on the remote peer. - * - * Returns: the number of buffers registered on the remote peer - */ - int numRemoteBuffers(); - - /* Get the BufferHandle returned by a call to registerBuffer(...) 
on the remote peer as identified by the index - * - * Inputs: - * index: the index of the handle to get - * - * Returns: a handle to the buffer on the remote peer - */ - BufferHandle getRemoteBuffer(int index); - ConnectionEpoch getEpoch(); + /* Connection can not be constructed from user code and must instead be created through Communicator::connect */ + Connection(std::unique_ptr); + ~Connection(); - DeviceProxyFifo getDeviceFifo(); - - void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size); - - void signal(); + void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size); void flush(); - void wait(); - -private: - std::unique_ptr pimpl; - friend class Communicator; -}; - -/*************************************************************************************************************** - * A mscclppDevConn provides a zero-copy connection between two GPUs connected via P2P NVLink or InfiniBand. - * The communication API is one-sided meaning that for every single data transfer, only one side - * needs to execute unlike a two-sided communication stack such as NCCL where both sides - * need to execute a send and a receive instruction, respectively, for every transfer. - * - * A connection is uniquely identified by the (remoteRank, tag) pair at an endpoint. - * The two endpoints register buffers of the same size with the connection. - * - * The endpoints provide the remoteRank, tag, and the buffer when registering a connection with msccppConnect(). - * - * mscllppConnectionSetup() sets up all the registered connections. - * - *************************************************************************************************************** - * A proxy thread running on the CPU is necessary to perform transfers using InfiniBand or the DMA engine. - * The current implementation uses a single proxy thread per context - one IB connection or DMA engine per node. 
- * Thus multiple threadblocks using different connections might use the same CPU proxy thread. - * - * Before using any of functionality of connections, mscclppProxyLaunch needs to be called to spawn the - * proxy threads. There are currently two types of connections: - * - * P2P via NVLink: the DMA engine can perform the copy between the buffers. DMA engine has higher latency - * but has a higher bandwidth and costs no compute cycles on the GPU. - * - * InfiniBand: the RDMA engine copies the data over MLX devices. - * - *************************************************************************************************************** - * At the runtime, a GPU kernel has access to a mscclppDevConn object that provides the following functions: - * - * put(): [non-blocking] the sender initiates a data transfer to the receiver. - * - * signal(): [non-blocking] the sender signals the receiver that data is ready to be consumed. - * - * flush(): [blocking] the sender waits for all the data transfers to complete - * - * wait(): [blocking] the reciever waits on the signal() to start reading the data. - * - * The sender should not reuse the buffer till the flush() returns. - * The receiver should only access the data after the wait() returns. - * - * putWithSignal(): the sender initiates a data transfer and signals the receiver that data is ready to be consumed. - * This is an optimized version of a put() followed by a signal(). - * - * These functions hide the complexity of syncrhonization between the two GPUs and the CPU proxy thread. 
- * Example: - * - * // sender GPU - * devConn.put(data1) - * // not OK to write to data1 - * devConn.put(data2) - * // not OK to write to data1, data2 - * devConn.put(data3) // receiver GPU - * // not OK to write to data1, data2, data3 // not OK to read data1, data2, data3 - * devConn.signal() -------------------------------> devConn.wait() - * // not OK to write to data1, data2, data3 // OK to read data1, data2, data3 - * devConn.flush() - * // OK to write to data1, data2, data3 - * - * - * The two endpoint can concurrently use the same connection provided they are writing (puts) on different - * indices in the registered buffer. - **************************************************************************************************************/ -struct DeviceConnection { - DeviceConnection() = default; - - DeviceConnection(HostConnection& hostConn) - : connectionId(hostConn.getId()), epoch(hostConn.getEpoch()), - fifo(hostConn.getDeviceFifo()) {} - - DeviceConnection(const DeviceConnection& other) = default; - - DeviceConnection& operator=(DeviceConnection& other) = default; + TransportFlags transport(); + TransportFlags remoteTransport(); // Good to have because different IB transports can still connect to each other -#ifdef __CUDACC__ - __forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) - { - fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value); - } + // template void write(RegisteredPtr dst, RegisteredPtr src, uint64_t size) { + // write(dst.memory(), dst.offset() * sizeof(T), src.memory(), src.offset() * sizeof(T), size); + // } - __forceinline__ __device__ void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) - { - put(dst, offset, src, offset, size); - } - - __forceinline__ __device__ void signal() - { - epochIncrement(); - fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1, connectionId).value); - } - - 
__forceinline__ __device__ void putWithSignal(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) - { - epochIncrement(); - fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size, connectionId).value); - } - - __forceinline__ __device__ void putWithSignal(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) - { - putWithSignal(dst, offset, src, offset, size); - } - - __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) - { - epochIncrement(); - uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag | channelTriggerSync, dst, dstOffset, src, srcOffset, size, connectionId).value); - while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && - *(volatile uint64_t*)fifo.tailReplica <= curFifoHead) - ; - } - - __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) - { - putWithSignalAndFlush(dst, offset, src, offset, size); - } - - __forceinline__ __device__ void flush() - { - uint64_t curFifoHead = fifo.push(ChannelTrigger(mscclppSync, 0, 0, 0, 0, 1, connectionId).value); - // we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail - // to go pass by curFifoHead (this is safety net) and (2) wait for the work element value to change to 0. 
- while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && - *(volatile uint64_t*)fifo.tailReplica <= curFifoHead) - ; - } - - __forceinline__ __device__ void wait() - { - epoch.wait(); - } - - __forceinline__ __device__ void epochIncrement() - { - epoch.epochIncrement(); - } -#endif // __CUDACC__ - - int connectionId; - - ConnectionEpoch epoch; - - // this is a concurrent fifo which is multiple threads from the device - // can produce for and the sole proxy thread consumes it. - DeviceProxyFifo fifo; -}; - -struct SimpleDeviceConnection { - SimpleDeviceConnection() = default; - - SimpleDeviceConnection(HostConnection& hostConn) : devConn(hostConn) { - dst = hostConn.getRemoteBuffer(0); - src = hostConn.getLocalBuffer(0); - } - - SimpleDeviceConnection(const SimpleDeviceConnection& other) = default; - - SimpleDeviceConnection& operator=(SimpleDeviceConnection& other) = default; - -#ifdef __CUDACC__ - - __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) - { - devConn.put(dst, dstOffset, src, srcOffset, size); - } - - __forceinline__ __device__ void put(uint64_t offset, uint64_t size) - { - put(offset, offset, size); - } - - __forceinline__ __device__ void signal() - { - devConn.signal(); - } - - __forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) - { - devConn.putWithSignal(dst, dstOffset, src, srcOffset, size); - } - - __forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) - { - putWithSignal(offset, offset, size); - } - - __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) - { - devConn.putWithSignalAndFlush(dst, dstOffset, src, srcOffset, size); - } - - __forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size) - { - putWithSignalAndFlush(offset, offset, size); - } - - __forceinline__ __device__ void flush() - { - devConn.flush(); - } - - 
__forceinline__ __device__ void wait() - { - devConn.wait(); - } - - __forceinline__ __device__ void epochIncrement() - { - devConn.epochIncrement(); - } - -#endif // __CUDACC__ - - DeviceConnection devConn; - BufferHandle dst; - BufferHandle src; -}; - -#define MSCCLPP_UNIQUE_ID_BYTES 128 -struct UniqueId { - char internal[MSCCLPP_UNIQUE_ID_BYTES]; -}; - -/* Create a unique ID for communication. Only needs to be called by one process. - * Use with mscclppCommInitRankFromId(). - * All processes need to provide the same ID to mscclppCommInitRankFromId(). - * - * Outputs: - * uniqueId: the unique ID to be created - */ -std::unique_ptr getUniqueId(); - -/* Transport Types */ -enum class TransportType : uint8_t { - P2P = 0, - IB = 1, + friend class Communicator; }; class Communicator { + struct Impl; + std::unique_ptr pimpl; public: /* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function. @@ -436,6 +119,16 @@ class Communicator { /* A no-op function that is used to synchronize all processes via a bootstrap allgather*/ void bootstrapBarrier(); + /* Register a region of GPU memory for use in this communicator. + * + * Inputs: + * data: base pointer to the memory + * size: size of the memory region in bytes + * + * Returns: a handle to the buffer + */ + RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports); + /* Connect to a remote rank. This function only prepares metadata for connection. The actual connection * is made by a following call of mscclppConnectionSetup(). Note that this function is two-way and a connection * from rank i to remote rank j needs to have a counterpart from rank j to rank i. @@ -450,19 +143,8 @@ class Communicator { * transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB) * ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P. 
*/ - std::shared_ptr connect(int remoteRank, int tag, TransportType transportType, const char* ibDev = 0); - - /* Establish all connections created by mscclppConnect(). This function must be called after all mscclppConnect() - * calls are made. This function ensures that all remote ranks are ready to communicate when it returns. - */ - void connectionSetup(); - - /* Launch proxy thread(s). This function is supposed to be called before starting a kernel that uses DeviceConnection. */ - void startProxying(); + std::shared_ptr connect(int remoteRank, int tag, TransportFlags transport); - /* Stop proxy thread(s). */ - void stopProxying(); - /* Return the rank of the calling process. * * Outputs: @@ -476,37 +158,6 @@ class Communicator { * size: the number of ranks of the communicator */ int size(); - - struct Impl; -private: - std::unique_ptr pimpl; - friend class HostConnection; -}; - -enum class ProxyHandlerResult { - Continue, - FlushFifoTailAndContinue, - Stop, -}; - -class Proxy; -using ProxyHandler = std::function; - -class Proxy { -public: - Proxy(ProxyHandler handler); - - ~Proxy(); - - void start(); - - void stop(); - - HostProxyFifo& fifo(); - -private: - struct Impl; - std::unique_ptr pimpl; }; } // namespace mscclpp diff --git a/src/include/proxy.hpp b/src/include/proxy.hpp new file mode 100644 index 000000000..70b6ba493 --- /dev/null +++ b/src/include/proxy.hpp @@ -0,0 +1,39 @@ +#ifndef MSCCLPP_PROXY_HPP_ +#define MSCCLPP_PROXY_HPP_ + +#include + +#include +#include + +namespace mscclpp { + +enum class ProxyHandlerResult { + Continue, + FlushFifoTailAndContinue, + Stop, +}; + +class Proxy; +using ProxyHandler = std::function; + +class Proxy { +public: + Proxy(ProxyHandler handler); + + ~Proxy(); + + void start(); + + void stop(); + + HostProxyFifo& fifo(); + +private: + struct Impl; + std::unique_ptr pimpl; +}; + +} // namespace mscclpp + +#endif // MSCCLPP_PROXY_HPP_ \ No newline at end of file diff --git a/src/include/registered_ptr.hpp 
b/src/include/registered_ptr.hpp new file mode 100644 index 000000000..7eadb6b0f --- /dev/null +++ b/src/include/registered_ptr.hpp @@ -0,0 +1,40 @@ +#ifndef MSCCLPP_REGISTERED_PTR_HPP_ +#define MSCCLPP_REGISTERED_PTR_HPP_ + +namespace mscclpp { + +template +class RegisteredPtr { + RegisteredMemory memory; + size_t offset; +public: + RegisteredPtr(RegisteredMemory memory, size_t offset) : memory(memory), offset(offset) {} + RegisteredPtr(RegisteredMemory memory) : RegisteredPtr(memory, 0) {} + ~RegisteredPtr() {} + + RegisteredMemory memory() { + return memory; + } + + T* data() { + return reinterpret_cast(memory.data()); + } + + size_t size() { + return memory.size() / sizeof(T); + } + + size_t offset() { + return offset; + } + + RegisteredPtr operator+(size_t offset) { + return RegisteredPtr(memory, this->offset + offset); + } + + // TODO: all other relevant overloads +}; + +} // namespace mscclpp + +#endif // MSCCLPP_REGISTERED_PTR_HPP_ \ No newline at end of file From 35ade686ff502386f1bf09640ac12c31a26a8e8d Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Sun, 23 Apr 2023 14:47:07 +0000 Subject: [PATCH 02/54] IB in cpp style WIP --- src/communicator.cc | 16 +-- src/ib.cc | 191 +++++++++++++++++++++++++++-------- src/include/channel.hpp | 6 +- src/include/communicator.hpp | 4 +- src/include/ib.hpp | 61 +++++++++++ 5 files changed, 226 insertions(+), 52 deletions(-) create mode 100644 src/include/ib.hpp diff --git a/src/communicator.cc b/src/communicator.cc index 5a843c789..d12b20e43 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -1,3 +1,4 @@ +#include "mscclpp.hpp" #include "communicator.hpp" #include "host_connection.hpp" #include "comm.h" @@ -16,14 +17,14 @@ Communicator::Impl::~Impl() { MSCCLPP_API_CPP Communicator::~Communicator() = default; -mscclppTransport_t transportTypeToCStyle(TransportType type) { - switch (type) { - case TransportType::IB: +static mscclppTransport_t transportFlagsToCStyle(TransportFlags flags) { + switch (flags) { + 
case TransportIB: return mscclppTransportIB; - case TransportType::P2P: + case TransportCudaIpc: return mscclppTransportP2P; default: - throw std::runtime_error("Unknown transport type"); + throw std::runtime_error("Unsupported conversion"); } } @@ -45,9 +46,8 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() { mscclppBootstrapBarrier(pimpl->comm); } -MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, - TransportType transportType, const char* ibDev) { - mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportTypeToCStyle(transportType), ibDev); +MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, TransportFlags transportFlags, const char* ibDev) { + mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportFlagsToCStyle(transportFlags), ibDev); auto connIdx = pimpl->connections.size(); auto conn = std::make_shared(std::make_unique(this, &pimpl->comm->conns[connIdx])); pimpl->connections.push_back(conn); diff --git a/src/ib.cc b/src/ib.cc index bb574e21d..4a0947619 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -9,48 +10,8 @@ #include "comm.h" #include "debug.h" #include "ib.h" - -static int getIbDevNumaNode(const char* ibDevPath) -{ - if (ibDevPath == NULL) { - WARN("ibDevPath is NULL"); - return -1; - } - const char* postfix = "/device/numa_node"; - FILE* fp = NULL; - char* filePath = NULL; - int node = -1; - int res; - if (mscclppCalloc(&filePath, strlen(ibDevPath) + strlen(postfix) + 1) != mscclppSuccess) { - WARN("mscclppCalloc failed"); - goto exit; - } - memcpy(filePath, ibDevPath, strlen(ibDevPath) * sizeof(char)); - filePath[strlen(ibDevPath)] = '\0'; - if (strncat(filePath, postfix, strlen(postfix)) == NULL) { - WARN("strncat failed"); - goto exit; - } - fp = fopen(filePath, "r"); - if (fp == NULL) { - WARN("fopen failed (errno %d, path %s)", errno, filePath); - goto exit; - } - res = fscanf(fp, "%d", 
&node); - if (res != 1) { - WARN("fscanf failed (errno %d, path %s)", errno, filePath); - node = -1; - goto exit; - } -exit: - if (filePath != NULL) { - free(filePath); - } - if (fp != NULL) { - fclose(fp); - } - return node; -} +#include "ib.hpp" +#include "checks.hpp" mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName) { @@ -400,3 +361,149 @@ int mscclppIbQp::pollCq() { return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs); } + +namespace mscclpp { + +IbQp::IbQp(void* ctx, void* pd, int port) +{ + struct ibv_context* _ctx = static_cast(ctx); + struct ibv_pd* _pd = static_cast(pd); + + this->cq = ibv_create_cq(_ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0); + if (this->cq == nullptr) { + std::stringstream err; + err << "ibv_create_cq failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } + + struct ibv_qp_init_attr qpInitAttr; + std::memset(&qpInitAttr, 0, sizeof(qpInitAttr)); + qpInitAttr.sq_sig_all = 0; + qpInitAttr.send_cq = static_cast(this->cq); + qpInitAttr.recv_cq = static_cast(this->cq); + qpInitAttr.qp_type = IBV_QPT_RC; + qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; + qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; + qpInitAttr.cap.max_send_sge = 1; + qpInitAttr.cap.max_recv_sge = 1; + qpInitAttr.cap.max_inline_data = 0; + + struct ibv_qp* _qp = ibv_create_qp(_pd, &qpInitAttr); + if (_qp == nullptr) { + std::stringstream err; + err << "ibv_create_qp failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } + + struct ibv_port_attr portAttr; + if (ibv_query_port(_ctx, port, &portAttr) != 0) { + std::stringstream err; + err << "ibv_query_port failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } + this->info.lid = portAttr.lid; + this->info.port = port; + this->info.linkLayer = portAttr.link_layer; + this->info.qpn = _qp->qp_num; + this->info.mtu = portAttr.active_mtu; + if (portAttr.link_layer != 
IBV_LINK_LAYER_INFINIBAND) { + union ibv_gid gid; + if (ibv_query_gid(_ctx, port, 0, &gid) != 0) { + std::stringstream err; + err << "ibv_query_gid failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } + this->info.spn = gid.global.subnet_prefix; + } + + struct ibv_qp_attr qpAttr; + memset(&qpAttr, 0, sizeof(qpAttr)); + qpAttr.qp_state = IBV_QPS_INIT; + qpAttr.pkey_index = 0; + qpAttr.port_num = port; + qpAttr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; + if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + std::stringstream err; + err << "ibv_modify_qp failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } + this->qp = _qp; +} + +IbCtx::IbCtx(const std::string& ibDevName) +{ + int num; + struct ibv_device** devices = ibv_get_device_list(&num); + for (int i = 0; i < num; ++i) { + if (std::string(devices[i]->name) == ibDevName) { + this->ctx = ibv_open_device(devices[i]); + break; + } + } + ibv_free_device_list(devices); + if (this->ctx == nullptr) { + std::stringstream err; + err << "ibv_open_device failed (errno " << errno << ", device name << " << ibDevName << ")"; + throw std::runtime_error(err.str()); + } + this->pd = ibv_alloc_pd(static_cast(this->ctx)); + if (this->pd == nullptr) { + std::stringstream err; + err << "ibv_alloc_pd failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } +} + +IbCtx::~IbCtx() +{ + if (this->pd != nullptr) { + ibv_dealloc_pd(static_cast(this->pd)); + } + if (this->ctx != nullptr) { + ibv_close_device(static_cast(this->ctx)); + } +} + +bool IbCtx::isPortUsable(int port) const +{ + struct ibv_port_attr portAttr; + if (ibv_query_port(static_cast(this->ctx), port, &portAttr) != 0) { + std::stringstream err; + err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")"; + throw std::runtime_error(err.str()); + } + return portAttr.state == IBV_PORT_ACTIVE && 
(portAttr.link_layer == IBV_LINK_LAYER_ETHERNET || + portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND); +} + +int IbCtx::getAnyActivePort() const +{ + struct ibv_device_attr devAttr; + if (ibv_query_device(static_cast(this->ctx), &devAttr) != 0) { + std::stringstream err; + err << "ibv_query_device failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } + for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) { + if (this->isPortUsable(port)) { + return port; + } + } + return -1; +} + +IbQp* IbCtx::createQp(int port /*=-1*/) +{ + if (port == -1) { + port = this->getAnyActivePort(); + if (port == -1) { + throw std::runtime_error("No active port found"); + } + } else if (!this->isPortUsable(port)) { + throw std::runtime_error("invalid IB port: " + std::to_string(port)); + } + qps.emplace_back(new IbQp(this->ctx, this->pd, port)); + return qps.back().get(); +} + +} // namespace mscclpp diff --git a/src/include/channel.hpp b/src/include/channel.hpp index cb1931b07..10a5f6016 100644 --- a/src/include/channel.hpp +++ b/src/include/channel.hpp @@ -2,6 +2,7 @@ #define MSCCLPP_CHANNEL_HPP_ #include "mscclpp.hpp" +#include "epoch.hpp" #include "proxy.hpp" namespace mscclpp { @@ -88,7 +89,7 @@ class HostConnection { ~HostConnection(); - void write() + void write(); int getId(); @@ -293,3 +294,6 @@ struct SimpleDeviceConnection { BufferHandle src; }; +} // namespace mscclpp + +#endif // MSCCLPP_CHANNEL_HPP_ diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 8294eeb6e..f2816c1aa 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -3,6 +3,8 @@ #include "mscclpp.hpp" #include "mscclpp.h" +#include "channel.hpp" +#include "proxy.hpp" namespace mscclpp { @@ -20,4 +22,4 @@ struct Communicator::Impl { } // namespace mscclpp -#endif \ No newline at end of file +#endif // MSCCL_COMMUNICATOR_HPP_ diff --git a/src/include/ib.hpp b/src/include/ib.hpp new file mode 100644 index 000000000..4c58cfdca --- 
/dev/null +++ b/src/include/ib.hpp @@ -0,0 +1,61 @@ +#ifndef MSCCLPP_IB_HPP_ +#define MSCCLPP_IB_HPP_ + +#include +#include +#include + +namespace mscclpp { + +// QP info to be shared with the remote peer +struct IbQpInfo +{ + uint16_t lid; + uint8_t port; + uint8_t linkLayer; + uint32_t qpn; + uint64_t spn; + uint32_t mtu; +}; + +class IbQp +{ +public: + ~IbQp(); + + IbQpInfo info; + +private: + IbQp(void* ctx, void* pd, int port); + + void* qp; + void* cq; + void* wcs; + void* wrs; + void* sges; + int wrn; + + friend class IbCtx; +}; + + +class IbCtx +{ +public: + IbCtx(const std::string& ibDevName); + ~IbCtx(); + + IbQp* createQp(int port = -1); + +private: + bool IbCtx::isPortUsable(int port) const; + int IbCtx::getAnyActivePort() const; + + void* ctx; + void* pd; + std::list> qps; +}; + +} // namespace mscclpp + +#endif // MSCCLPP_IB_HPP_ From e4ee2eba25de399e4242b5ee9fd9f607b1b40e88 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Tue, 25 Apr 2023 00:41:45 +0000 Subject: [PATCH 03/54] WIP Connection in C++ --- src/communicator.cc | 38 ++++++++++++++-------- src/connection.cc | 54 +++++++++++++++++++++++++++++++ src/include/communicator.hpp | 12 +++---- src/include/connection.hpp | 48 +++++++++++++++++++++++++++ src/include/mscclpp.hpp | 38 +++++++++++----------- src/include/registered_memory.hpp | 46 ++++++++++++++++++++++++++ src/registered_memory.cc | 7 ++++ 7 files changed, 205 insertions(+), 38 deletions(-) create mode 100644 src/connection.cc create mode 100644 src/include/connection.hpp create mode 100644 src/include/registered_memory.hpp create mode 100644 src/registered_memory.cc diff --git a/src/communicator.cc b/src/communicator.cc index d12b20e43..a74923bb3 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -17,9 +17,16 @@ Communicator::Impl::~Impl() { MSCCLPP_API_CPP Communicator::~Communicator() = default; -static mscclppTransport_t transportFlagsToCStyle(TransportFlags flags) { +static mscclppTransport_t 
transportToCStyle(TransportFlags flags) { switch (flags) { - case TransportIB: + case TransportIB0: + case TransportIB1: + case TransportIB2: + case TransportIB3: + case TransportIB4: + case TransportIB5: + case TransportIB6: + case TransportIB7: return mscclppTransportIB; case TransportCudaIpc: return mscclppTransportP2P; @@ -46,10 +53,23 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() { mscclppBootstrapBarrier(pimpl->comm); } -MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, TransportFlags transportFlags, const char* ibDev) { - mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportFlagsToCStyle(transportFlags), ibDev); +MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, TransportFlags transport) { + std::string ibDev; + switch (transport) { + case TransportIB0: + case TransportIB1: + case TransportIB2: + case TransportIB3: + case TransportIB4: + case TransportIB5: + case TransportIB6: + case TransportIB7: + ibDev = getIBDeviceName(transport); + break; + } + mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportToCStyle(transport), ibDev.c_str()); auto connIdx = pimpl->connections.size(); - auto conn = std::make_shared(std::make_unique(this, &pimpl->comm->conns[connIdx])); + auto conn = std::make_shared(std::make_unique(this, &pimpl->comm->conns[connIdx])); pimpl->connections.push_back(conn); return conn; } @@ -58,14 +78,6 @@ MSCCLPP_API_CPP void Communicator::connectionSetup() { mscclppConnectionSetup(pimpl->comm); } -MSCCLPP_API_CPP void Communicator::startProxying() { - pimpl->proxy.start(); -} - -MSCCLPP_API_CPP void Communicator::stopProxying() { - pimpl->proxy.stop(); -} - MSCCLPP_API_CPP int Communicator::rank() { int result; mscclppCommRank(pimpl->comm, &result); diff --git a/src/connection.cc b/src/connection.cc new file mode 100644 index 000000000..12ebee027 --- /dev/null +++ b/src/connection.cc @@ -0,0 +1,54 @@ +#include "connection.hpp" +#include "checks.hpp" 
+#include "registered_memory.hpp" + +namespace mscclpp { + +void validateTransport(RegisteredMemory mem, TransportFlags transport) { + if (mem.transports() & transport == TransportNone) { + throw std::runtime_error("mem does not support transport"); + } +} + +TransportFlags CudaIpcConnection::transport() { + return TransportCudaIpc; +} + +TransportFlags CudaIpcConnection::remoteTransport() { + return TransportCudaIpc; +} + +CudaIpcConnection::CudaIpcConnection() { + cudaStreamCreate(&stream); +} + +CudaIpcConnection::~CudaIpcConnection() { + cudaStreamDestroy(stream); +} + +void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { + validateTransport(dst, remoteTransport()); + validateTransport(src, transport()); + + auto dstPtr = dst.impl->getTransportData(remoteTransport()); + auto srcPtr = src.impl->getTransportData(transport()); + CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream)); + npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)dataSize); +} + +void CudaIpcConnection::flush() { + CUDATHROW(cudaStreamSynchronize(stream)); + npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT); +} + +IBConnection::IBConnection(TransportFlags transport) : transport_(transport), remoteTransport_(TransportNone) {} + +TransportFlags IBConnection::transport() { + return transport_; +} + +TransportFlags IBConnection::remoteTransport() { + return remoteTransport_; +} + +} // namespace mscclpp diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index f2816c1aa..827b02814 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -9,15 +9,15 @@ namespace mscclpp { struct Communicator::Impl { - mscclppComm_t comm; - std::vector> connections; - Proxy proxy; + mscclppComm_t comm; + std::vector> connections; + Proxy proxy; - Impl(); + Impl(); - ~Impl(); + ~Impl(); - friend class HostConnection; 
+ friend class Connection; }; } // namespace mscclpp diff --git a/src/include/connection.hpp b/src/include/connection.hpp new file mode 100644 index 000000000..048e2c6ac --- /dev/null +++ b/src/include/connection.hpp @@ -0,0 +1,48 @@ +#ifndef MSCCLPP_CONNECTION_HPP_ +#define MSCCLPP_CONNECTION_HPP_ + +#include "mscclpp.hpp" +#include +#include "ib.h" + +namespace mscclpp { + +class CudaIpcConnection : public Connection { + cudaStream_t stream; +public: + + CudaIpcConnection(); + + virtual ~CudaIpcConnection(); + + virtual TransportFlags transport(); + + virtual TransportFlags remoteTransport(); + + virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size); + + virtual void flush(); +}; + +class IBConnection : public Connection { + TransportFlags transport_; + TransportFlags remoteTransport_; + mscclppIbQp qp; +public: + + IBConnection(TransportFlags transport); + + virtual ~IBConnection(); + + virtual TransportFlags transport(); + + virtual TransportFlags remoteTransport(); + + virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size); + + virtual void flush(); +}; + +} // namespace mscclpp + +#endif // MSCCLPP_CONNECTION_HPP_ diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index 67d400508..f4d73ab4a 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -26,8 +26,9 @@ struct UniqueId { std::unique_ptr getUniqueId(); using TransportFlags = uint32_t; +const TransportFlags TransportNone = 0b0; const TransportFlags TransportCudaIpc = 0b1; -const TransportFlags TransportIB = 0b10; +const TransportFlags TransportIB0 = 0b10; const TransportFlags TransportIB1 = 0b100; const TransportFlags TransportIB2 = 0b1000; const TransportFlags TransportIB3 = 0b10000; @@ -37,7 +38,12 @@ const TransportFlags TransportIB6 = 0b10000000; const TransportFlags TransportIB7 = 0b100000000; const TransportFlags TransportAll = 0b111111111; +int 
getIBDeviceCount(); +std::string getIBDeviceName(TransportFlags ibTransport); +TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName); + class Communicator; +class Connection; class RegisteredMemory { struct Impl; @@ -55,31 +61,20 @@ class RegisteredMemory { static RegisteredMemory deserialize(const std::vector& data); int rank(); - bool isLocal(); - bool isRemote(); + + friend class Connection; }; class Connection { - struct Impl; - std::unique_ptr pimpl; -public: + virtual ~Connection() = 0; - /* Connection can not be constructed from user code and must instead be created through Communicator::connect */ - Connection(std::unique_ptr); - ~Connection(); + virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) = 0; - void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size); + virtual void flush() = 0; - void flush(); + virtual TransportFlags transport() = 0; - TransportFlags transport(); - TransportFlags remoteTransport(); // Good to have because different IB transports can still connect to each other - - // template void write(RegisteredPtr dst, RegisteredPtr src, uint64_t size) { - // write(dst.memory(), dst.offset() * sizeof(T), src.memory(), src.offset() * sizeof(T), size); - // } - - friend class Communicator; + virtual TransportFlags remoteTransport() = 0; }; class Communicator { @@ -145,6 +140,11 @@ class Communicator { */ std::shared_ptr connect(int remoteRank, int tag, TransportFlags transport); + /* Establish all connections declared by connect(). This function must be called after all connect() + * calls are made. This function ensures that all remote ranks are ready to communicate when it returns. + */ + void connectionSetup(); + /* Return the rank of the calling process. 
* * Outputs: diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp new file mode 100644 index 000000000..82fe942e4 --- /dev/null +++ b/src/include/registered_memory.hpp @@ -0,0 +1,46 @@ +#ifndef MSCCLPP_REGISTERED_MEMORY_HPP_ +#define MSCCLPP_REGISTERED_MEMORY_HPP_ + +#include "mscclpp.hpp" +#include "ib.h" +#include +#include + +namespace mscclpp { + +struct IBTransportData { + mscclppIbMr localIbMr; + mscclppIbMrInfo remoteIbMrInfo; +}; + +struct TransportData { + TransportFlags transport; + union { + void* cudaIpcPtr; + IBTransportData ibData; + } +}; + +struct RegisteredMemory::Impl { + void* data; + size_t size; + TransportFlags transports; + std::vector transportData; + + Impl(void* data, size_t size, TransportFlags transports); + + ~Impl(); + + template T& getTransportData(TransportFlags transport) { + for (auto& data : transportData) { + if (data.transport == transport) { + return data; + } + } + throw std::runtime_error("Transport data not found"); + } +}; + +} // namespace mscclpp + +#endif // MSCCLPP_REGISTERED_MEMORY_HPP_ diff --git a/src/registered_memory.cc b/src/registered_memory.cc new file mode 100644 index 000000000..d491e72f2 --- /dev/null +++ b/src/registered_memory.cc @@ -0,0 +1,7 @@ +#include "registered_memory.hpp" + +namespace mscclpp { + + + +} // namespace mscclpp From 90a8860bcc45624a945dca33a480206fc861c41d Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Wed, 26 Apr 2023 03:04:56 +0000 Subject: [PATCH 04/54] Registered memory (de)serialization and Connection work --- src/connection.cc | 66 +++++++++++++++- src/include/connection.hpp | 6 +- src/include/mscclpp.hpp | 5 +- src/include/registered_memory.hpp | 29 +++---- src/registered_memory.cc | 121 ++++++++++++++++++++++++++++++ 5 files changed, 201 insertions(+), 26 deletions(-) diff --git a/src/connection.cc b/src/connection.cc index 12ebee027..48b2d1973 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -1,6 +1,7 @@ #include "connection.hpp" 
#include "checks.hpp" #include "registered_memory.hpp" +#include "npkit.h" namespace mscclpp { @@ -10,6 +11,8 @@ void validateTransport(RegisteredMemory mem, TransportFlags transport) { } } +// CudaIpcConnection + TransportFlags CudaIpcConnection::transport() { return TransportCudaIpc; } @@ -30,17 +33,20 @@ void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, Register validateTransport(dst, remoteTransport()); validateTransport(src, transport()); - auto dstPtr = dst.impl->getTransportData(remoteTransport()); - auto srcPtr = src.impl->getTransportData(transport()); + auto dstPtr = dst.impl->data; + auto srcPtr = src.impl->data; + CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream)); - npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)dataSize); + // npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size); } void CudaIpcConnection::flush() { CUDATHROW(cudaStreamSynchronize(stream)); - npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT); + // npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT); } +// IBConnection + IBConnection::IBConnection(TransportFlags transport) : transport_(transport), remoteTransport_(TransportNone) {} TransportFlags IBConnection::transport() { @@ -51,4 +57,56 @@ TransportFlags IBConnection::remoteTransport() { return remoteTransport_; } +IBConnection::IBConnection(TransportFlags transport, Communicator::Impl& commImpl) : transport_(transport), remoteTransport_(TransportNone) { + MSCCLPPTHROW(mscclppIbContextCreateQp(commImpl.getIbContext(transport), &qp)); +} + +IBConnection::~IBConnection() { + // TODO: Destroy QP? 
+} + +void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { + validateTransport(dst, remoteTransport()); + validateTransport(src, transport()); + + auto dstMrInfo = dst.impl->getTransportInfo(remoteTransport()); + auto srcMr = src.impl->getTransportInfo(transport()); + + qp->stageSend(srcMr, &dstMrInfo, (uint32_t)size, + /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false); + int ret = qp->postSend(); + if (ret != 0) { + // Return value is errno. + WARN("data postSend failed: errno %d", ret); + } + // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size); +} + +void IBConnection::flush() { + bool isWaiting = true; + while (isWaiting) { + int wcNum = qp->pollCq(); + if (wcNum < 0) { + WARN("pollCq failed: errno %d", errno); + continue; + } + for (int i = 0; i < wcNum; ++i) { + struct ibv_wc* wc = &qp->wcs[i]; + if (wc->status != IBV_WC_SUCCESS) { + WARN("wc status %d", wc->status); + continue; + } + if (wc->qp_num != qp->qp->qp_num) { + WARN("got wc of unknown qp_num %d", wc->qp_num); + continue; + } + if (wc->opcode == IBV_WC_RDMA_WRITE) { + isWaiting = false; + break; + } + } + } + // npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT); +} + } // namespace mscclpp diff --git a/src/include/connection.hpp b/src/include/connection.hpp index 048e2c6ac..72f0eb90d 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -7,6 +7,8 @@ namespace mscclpp { +// TODO: Add functionality to these classes for Communicator to do connectionSetup + class CudaIpcConnection : public Connection { cudaStream_t stream; public: @@ -27,10 +29,10 @@ class CudaIpcConnection : public Connection { class IBConnection : public Connection { TransportFlags transport_; TransportFlags remoteTransport_; - mscclppIbQp qp; + mscclppIbQp* qp; public: - IBConnection(TransportFlags transport); + IBConnection(TransportFlags transport, Communicator::Impl& 
commImpl); virtual ~IBConnection(); diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index f4d73ab4a..52b0511bf 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -36,7 +36,9 @@ const TransportFlags TransportIB4 = 0b100000; const TransportFlags TransportIB5 = 0b1000000; const TransportFlags TransportIB6 = 0b10000000; const TransportFlags TransportIB7 = 0b100000000; + const TransportFlags TransportAll = 0b111111111; +const TransportFlags TransportAllIB = 0b111111110; int getIBDeviceCount(); std::string getIBDeviceName(TransportFlags ibTransport); @@ -55,13 +57,12 @@ class RegisteredMemory { void* data(); size_t size(); + int rank(); TransportFlags transports(); std::vector serialize(); static RegisteredMemory deserialize(const std::vector& data); - int rank(); - friend class Connection; }; diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index 82fe942e4..24eed981d 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -2,39 +2,32 @@ #define MSCCLPP_REGISTERED_MEMORY_HPP_ #include "mscclpp.hpp" +#include "mscclpp.h" #include "ib.h" #include #include namespace mscclpp { -struct IBTransportData { - mscclppIbMr localIbMr; - mscclppIbMrInfo remoteIbMrInfo; -}; - -struct TransportData { +struct TransportInfo { TransportFlags transport; - union { - void* cudaIpcPtr; - IBTransportData ibData; - } + std::variant data; }; struct RegisteredMemory::Impl { void* data; size_t size; + int rank; TransportFlags transports; - std::vector transportData; - - Impl(void* data, size_t size, TransportFlags transports); + std::vector transportInfos; - ~Impl(); + Impl(void* data, size_t size, int rank, TransportFlags transports); + Impl(const std::vector& data); - template T& getTransportData(TransportFlags transport) { - for (auto& data : transportData) { - if (data.transport == transport) { - return data; + template T& getTransportInfo(TransportFlags transport) { + for (auto& entry : 
transportInfos) { + if (entry.transport == transport) { + return std::get(entry.data); } } throw std::runtime_error("Transport data not found"); diff --git a/src/registered_memory.cc b/src/registered_memory.cc index d491e72f2..eabb9e7d8 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -2,6 +2,127 @@ namespace mscclpp { +RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator& comm) : data(data), size(size), rank(rank), transports(transports) { + if (transports & TransportCudaIpc) { + TransportInfo transportInfo; + transportInfo.transport = TransportCudaIpc; + cudaIpcMemHandle_t handle; + CUDATHROW(cudaIpcGetMemHandle(&handle, data)); + transportInfo.data = handle; + this->transportInfos.push_back(transportInfo); + } + if (transports & TransportAllIB) { + auto addIb = [&](TransportFlags ibTransport) { + TransportInfo transportInfo; + transportInfo.transport = ibTransport; + mscclppIbMr* mr; + MSCCLPPTHROW(mscclppIbContextRegisterMr(comm.pimpl->getIbContext(ibTransport), data, size, &mr)); + transportInfo.data = mr; + this->transportInfos.push_back(transportInfo); + }; + if (transports & TransportIB0) addIb(TransportIB0); + if (transports & TransportIB1) addIb(TransportIB1); + if (transports & TransportIB2) addIb(TransportIB2); + if (transports & TransportIB3) addIb(TransportIB3); + if (transports & TransportIB4) addIb(TransportIB4); + if (transports & TransportIB5) addIb(TransportIB5); + if (transports & TransportIB6) addIb(TransportIB6); + if (transports & TransportIB7) addIb(TransportIB7); + } +} +RegisteredMemory::RegisteredMemory(std::shared_ptr pimpl) : impl(pimpl) {} + +RegisteredMemory::~RegisteredMemory() = default; + +void* RegisteredMemory::data() { + return impl->data; +} + +size_t RegisteredMemory::size() { + return impl->size; +} + +int RegisteredMemory::rank() { + return impl->rank; +} + +TransportFlags RegisteredMemory::transports() { + return impl->transports; +} + +std::vector 
RegisteredMemory::serialize() { + std::vector result; + std::copy_n(reinterpret_cast(&impl->size), sizeof(impl->size), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&impl->rank), sizeof(impl->rank), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&impl->transports), sizeof(impl->transports), std::back_inserter(result)); + if (impl->transportInfos.size() > std::numeric_limits::max()) { + throw std::runtime_error("Too many transport info entries"); + } + int8_t transportCount = impl->transportInfos.size(); + std::copy_n(reinterpret_cast(&transportCount), sizeof(transportCount), std::back_inserter(result)); + for (auto& entry : impl->transportInfos) { + std::copy_n(reinterpret_cast(&entry.transport), sizeof(entry.transport), std::back_inserter(result)); + std::visit(overloaded{ + [&](std::monostate&){ + throw std::runtime_error("Transport info not set"); + }, + [&](cudaIpcMemHandle_t handle){ + std::copy_n(reinterpret_cast(&handle), sizeof(handle), std::back_inserter(result)); + }, + [&](mscclppIbMr* mr){ + std::copy_n(reinterpret_cast(&mr->info), sizeof(mr->info), std::back_inserter(result)); + }, + [&](mscclppIbMrInfo info){ + std::copy_n(reinterpret_cast(&info), sizeof(info), std::back_inserter(result)); + } + }, entry.data); + } + return result; +} + +static RegisteredMemory RegisteredMemory::deserialize(const std::vector& data) { + return RegisteredMemory(std::make_shared(data)); +} + +RegisteredMemory::Impl::Impl(const std::vector& data) { + auto it = data.begin(); + std::copy_n(it, sizeof(this->size), reinterpret_cast(&this->size)); + it += sizeof(this->size); + std::copy_n(it, sizeof(this->rank), reinterpret_cast(&this->rank)); + it += sizeof(this->rank); + std::copy_n(it, sizeof(this->transports), reinterpret_cast(&this->transports)); + it += sizeof(this->transports); + int8_t transportCount; + std::copy_n(it, sizeof(transportCount), reinterpret_cast(&transportCount)); + it += sizeof(transportCount); + for (int i = 0; i < 
transportCount; ++i) { + TransportInfo transportInfo; + std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast(&transportInfo.transport)); + it += sizeof(transportInfo.transport); + if (transportInfo.transport & TransportCudaIpc) { + cudaIpcMemHandle_t handle; + std::copy_n(it, sizeof(handle), reinterpret_cast(&handle)); + it += sizeof(handle); + transportInfo.data = handle; + } else if (transportInfo.transport & TransportAllIB) { + mscclppIbMrInfo info; + std::copy_n(it, sizeof(info), reinterpret_cast(&info)); + it += sizeof(info); + transportInfo.data = info; + } else { + throw std::runtime_error("Unknown transport"); + } + this->transportInfos.push_back(transportInfo); + } + if (it != data.end()) { + throw std::runtime_error("Deserialization failed"); + } + + if (transports & TransportCudaIpc) { + auto cudaIpcHandle = getTransportInfo(TransportCudaIpc); + CUDATHROW(cudaIpcOpenMemHandle(&data, cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); + } +} } // namespace mscclpp From d746201287d63407ac110c5d6aae85a3aafddab2 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Wed, 26 Apr 2023 17:46:47 +0000 Subject: [PATCH 05/54] WIP builds, but doesn't link --- Makefile | 5 ++- src/communicator.cc | 50 ++++++++++++++-------- src/connection.cc | 53 +++++++++++++---------- src/include/communicator.hpp | 6 ++- src/include/connection.hpp | 21 ++++----- src/include/ib.hpp | 4 +- src/include/mscclpp.hpp | 13 +++--- src/include/registered_memory.hpp | 17 +++++--- src/registered_memory.cc | 71 +++++++++++++++---------------- 9 files changed, 136 insertions(+), 104 deletions(-) diff --git a/Makefile b/Makefile index e544aeee3..9aaf34b80 100644 --- a/Makefile +++ b/Makefile @@ -120,7 +120,8 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc) LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc) -LIBSRCS += $(addprefix src/,communicator.cc fifo.cc host_connection.cc 
proxy_cpp.cc basic_proxy_handler.cc) +LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc) +#LIBSRCS += $(addprefix src/,fifo.cc host_connection.cc proxy_cpp.cc basic_proxy_handler.cc) ifneq ($(NPKIT), 0) LIBSRCS += $(addprefix src/misc/,npkit.cc) endif @@ -148,7 +149,7 @@ UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS)) TESTSDIR := tests -TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu allgather_test_cpp.cu) +TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu) # allgather_test_cpp.cu TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS)) TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS)) diff --git a/src/communicator.cc b/src/communicator.cc index a74923bb3..316801de5 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -4,17 +4,39 @@ #include "comm.h" #include "basic_proxy_handler.hpp" #include "api.h" +#include "utils.h" +#include "checks.hpp" +#include "debug.h" +#include "connection.hpp" namespace mscclpp { -Communicator::Impl::Impl() : comm(nullptr), proxy(makeBasicProxyHandler(*this)) {} +Communicator::Impl::Impl() : comm(nullptr) {} Communicator::Impl::~Impl() { + for (auto& entry : ibContexts) { + mscclppIbContextDestroy(entry.second); + } + ibContexts.clear(); if (comm) { mscclppCommDestroy(comm); } } +mscclppIbContext* Communicator::Impl::getIbContext(TransportFlags ibTransport) { + // Find IB context or create it + auto it = ibContexts.find(ibTransport); + if (it == ibContexts.end()) { + auto ibDev = getIBDeviceName(ibTransport); + mscclppIbContext* ibCtx; + MSCCLPPTHROW(mscclppIbContextCreate(&ibCtx, ibDev.c_str())); + ibContexts[ibTransport] = ibCtx; + return ibCtx; + } else { + return it->second; + } +} + MSCCLPP_API_CPP Communicator::~Communicator() = default; static mscclppTransport_t 
transportToCStyle(TransportFlags flags) { @@ -54,24 +76,16 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() { } MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, TransportFlags transport) { - std::string ibDev; - switch (transport) { - case TransportIB0: - case TransportIB1: - case TransportIB2: - case TransportIB3: - case TransportIB4: - case TransportIB5: - case TransportIB6: - case TransportIB7: - ibDev = getIBDeviceName(transport); - break; + std::shared_ptr conn; + if (transport | TransportCudaIpc) { + auto cudaIpcConn = std::make_shared(); + conn = cudaIpcConn; + } else if (transport | TransportAllIB) { + auto ibConn = std::make_shared(transport, *pimpl); + conn = ibConn; + } else { + throw std::runtime_error("Unsupported transport"); } - mscclppConnectWithoutBuffer(pimpl->comm, remoteRank, tag, transportToCStyle(transport), ibDev.c_str()); - auto connIdx = pimpl->connections.size(); - auto conn = std::make_shared(std::make_unique(this, &pimpl->comm->conns[connIdx])); - pimpl->connections.push_back(conn); - return conn; } MSCCLPP_API_CPP void Communicator::connectionSetup() { diff --git a/src/connection.cc b/src/connection.cc index 48b2d1973..3e053cb32 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -1,26 +1,18 @@ #include "connection.hpp" #include "checks.hpp" #include "registered_memory.hpp" -#include "npkit.h" +#include "npkit/npkit.h" namespace mscclpp { void validateTransport(RegisteredMemory mem, TransportFlags transport) { - if (mem.transports() & transport == TransportNone) { + if ((mem.transports() & transport) == TransportNone) { throw std::runtime_error("mem does not support transport"); } } // CudaIpcConnection -TransportFlags CudaIpcConnection::transport() { - return TransportCudaIpc; -} - -TransportFlags CudaIpcConnection::remoteTransport() { - return TransportCudaIpc; -} - CudaIpcConnection::CudaIpcConnection() { cudaStreamCreate(&stream); } @@ -29,12 +21,20 @@ 
CudaIpcConnection::~CudaIpcConnection() { cudaStreamDestroy(stream); } +TransportFlags CudaIpcConnection::transport() { + return TransportCudaIpc; +} + +TransportFlags CudaIpcConnection::remoteTransport() { + return TransportCudaIpc; +} + void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { validateTransport(dst, remoteTransport()); validateTransport(src, transport()); - auto dstPtr = dst.impl->data; - auto srcPtr = src.impl->data; + auto dstPtr = dst.data(); + auto srcPtr = src.data(); CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream)); // npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size); @@ -47,7 +47,13 @@ void CudaIpcConnection::flush() { // IBConnection -IBConnection::IBConnection(TransportFlags transport) : transport_(transport), remoteTransport_(TransportNone) {} +IBConnection::IBConnection(TransportFlags transport, Communicator::Impl& commImpl) : transport_(transport), remoteTransport_(TransportNone) { + MSCCLPPTHROW(mscclppIbContextCreateQp(commImpl.getIbContext(transport), &qp)); +} + +IBConnection::~IBConnection() { + // TODO: Destroy QP? +} TransportFlags IBConnection::transport() { return transport_; @@ -57,20 +63,21 @@ TransportFlags IBConnection::remoteTransport() { return remoteTransport_; } -IBConnection::IBConnection(TransportFlags transport, Communicator::Impl& commImpl) : transport_(transport), remoteTransport_(TransportNone) { - MSCCLPPTHROW(mscclppIbContextCreateQp(commImpl.getIbContext(transport), &qp)); -} - -IBConnection::~IBConnection() { - // TODO: Destroy QP? 
-} - void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { validateTransport(dst, remoteTransport()); validateTransport(src, transport()); - auto dstMrInfo = dst.impl->getTransportInfo(remoteTransport()); - auto srcMr = src.impl->getTransportInfo(transport()); + auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport()); + if (dstTransportInfo.ibLocal) { + throw std::runtime_error("dst is local, which is not supported"); + } + auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(remoteTransport()); + if (!srcTransportInfo.ibLocal) { + throw std::runtime_error("src is remote, which is not supported"); + } + + auto dstMrInfo = dstTransportInfo.ibMrInfo; + auto srcMr = srcTransportInfo.ibMr; qp->stageSend(srcMr, &dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false); diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 827b02814..8eb0e2026 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -5,19 +5,21 @@ #include "mscclpp.h" #include "channel.hpp" #include "proxy.hpp" +#include "ib.h" +#include namespace mscclpp { struct Communicator::Impl { mscclppComm_t comm; std::vector> connections; - Proxy proxy; + std::unordered_map ibContexts; Impl(); ~Impl(); - friend class Connection; + mscclppIbContext* getIbContext(TransportFlags ibTransport); }; } // namespace mscclpp diff --git a/src/include/connection.hpp b/src/include/connection.hpp index 72f0eb90d..94d727e77 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -4,6 +4,7 @@ #include "mscclpp.hpp" #include #include "ib.h" +#include "communicator.hpp" namespace mscclpp { @@ -15,15 +16,15 @@ class CudaIpcConnection : public Connection { CudaIpcConnection(); - virtual ~CudaIpcConnection(); + ~CudaIpcConnection(); - virtual TransportFlags transport(); + TransportFlags 
transport() override; - virtual TransportFlags remoteTransport(); + TransportFlags remoteTransport() override; - virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size); + void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; - virtual void flush(); + void flush() override; }; class IBConnection : public Connection { @@ -34,15 +35,15 @@ class IBConnection : public Connection { IBConnection(TransportFlags transport, Communicator::Impl& commImpl); - virtual ~IBConnection(); + ~IBConnection(); - virtual TransportFlags transport(); + TransportFlags transport() override; - virtual TransportFlags remoteTransport(); + TransportFlags remoteTransport() override; - virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size); + void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; - virtual void flush(); + void flush() override; }; } // namespace mscclpp diff --git a/src/include/ib.hpp b/src/include/ib.hpp index 4c58cfdca..85c92af78 100644 --- a/src/include/ib.hpp +++ b/src/include/ib.hpp @@ -48,8 +48,8 @@ class IbCtx IbQp* createQp(int port = -1); private: - bool IbCtx::isPortUsable(int port) const; - int IbCtx::getAnyActivePort() const; + bool isPortUsable(int port) const; + int getAnyActivePort() const; void* ctx; void* pd; diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index 52b0511bf..9c699efb3 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -67,8 +67,7 @@ class RegisteredMemory { }; class Connection { - virtual ~Connection() = 0; - +public: virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) = 0; virtual void flush() = 0; @@ -76,13 +75,13 @@ class Connection { virtual TransportFlags transport() = 0; virtual 
TransportFlags remoteTransport() = 0; + +protected: + static std::shared_ptr getRegisteredMemoryImpl(RegisteredMemory&); }; class Communicator { - struct Impl; - std::unique_ptr pimpl; public: - /* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function. * * Inputs: @@ -159,6 +158,10 @@ class Communicator { * size: the number of ranks of the communicator */ int size(); + + struct Impl; +private: + std::unique_ptr pimpl; }; } // namespace mscclpp diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index 24eed981d..7a0ab1d02 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -4,14 +4,21 @@ #include "mscclpp.hpp" #include "mscclpp.h" #include "ib.h" -#include +#include "communicator.hpp" #include namespace mscclpp { struct TransportInfo { TransportFlags transport; - std::variant data; + + // TODO: rewrite this using std::variant or something + bool ibLocal; + union { + cudaIpcMemHandle_t cudaIpcHandle; + mscclppIbMr* ibMr; + mscclppIbMrInfo ibMrInfo; + }; }; struct RegisteredMemory::Impl { @@ -21,13 +28,13 @@ struct RegisteredMemory::Impl { TransportFlags transports; std::vector transportInfos; - Impl(void* data, size_t size, int rank, TransportFlags transports); + Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl); Impl(const std::vector& data); - template T& getTransportInfo(TransportFlags transport) { + TransportInfo& getTransportInfo(TransportFlags transport) { for (auto& entry : transportInfos) { if (entry.transport == transport) { - return std::get(entry.data); + return entry; } } throw std::runtime_error("Transport data not found"); diff --git a/src/registered_memory.cc b/src/registered_memory.cc index eabb9e7d8..7a5a0725d 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -1,14 +1,16 @@ #include "registered_memory.hpp" +#include "checks.hpp" +#include namespace mscclpp { 
-RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator& comm) : data(data), size(size), rank(rank), transports(transports) { +RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) : data(data), size(size), rank(rank), transports(transports) { if (transports & TransportCudaIpc) { TransportInfo transportInfo; transportInfo.transport = TransportCudaIpc; cudaIpcMemHandle_t handle; CUDATHROW(cudaIpcGetMemHandle(&handle, data)); - transportInfo.data = handle; + transportInfo.cudaIpcHandle = handle; this->transportInfos.push_back(transportInfo); } if (transports & TransportAllIB) { @@ -16,8 +18,9 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t TransportInfo transportInfo; transportInfo.transport = ibTransport; mscclppIbMr* mr; - MSCCLPPTHROW(mscclppIbContextRegisterMr(comm.pimpl->getIbContext(ibTransport), data, size, &mr)); - transportInfo.data = mr; + MSCCLPPTHROW(mscclppIbContextRegisterMr(commImpl.getIbContext(ibTransport), data, size, &mr)); + transportInfo.ibMr = mr; + transportInfo.ibLocal = true; this->transportInfos.push_back(transportInfo); }; if (transports & TransportIB0) addIb(TransportIB0); @@ -31,62 +34,55 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t } } -RegisteredMemory::RegisteredMemory(std::shared_ptr pimpl) : impl(pimpl) {} +RegisteredMemory::RegisteredMemory(std::shared_ptr pimpl) : pimpl(pimpl) {} RegisteredMemory::~RegisteredMemory() = default; void* RegisteredMemory::data() { - return impl->data; + return pimpl->data; } size_t RegisteredMemory::size() { - return impl->size; + return pimpl->size; } int RegisteredMemory::rank() { - return impl->rank; + return pimpl->rank; } TransportFlags RegisteredMemory::transports() { - return impl->transports; + return pimpl->transports; } std::vector RegisteredMemory::serialize() { std::vector result; - 
std::copy_n(reinterpret_cast(&impl->size), sizeof(impl->size), std::back_inserter(result)); - std::copy_n(reinterpret_cast(&impl->rank), sizeof(impl->rank), std::back_inserter(result)); - std::copy_n(reinterpret_cast(&impl->transports), sizeof(impl->transports), std::back_inserter(result)); - if (impl->transportInfos.size() > std::numeric_limits::max()) { + std::copy_n(reinterpret_cast(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&pimpl->rank), sizeof(pimpl->rank), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result)); + if (pimpl->transportInfos.size() > std::numeric_limits::max()) { throw std::runtime_error("Too many transport info entries"); } - int8_t transportCount = impl->transportInfos.size(); + int8_t transportCount = pimpl->transportInfos.size(); std::copy_n(reinterpret_cast(&transportCount), sizeof(transportCount), std::back_inserter(result)); - for (auto& entry : impl->transportInfos) { + for (auto& entry : pimpl->transportInfos) { std::copy_n(reinterpret_cast(&entry.transport), sizeof(entry.transport), std::back_inserter(result)); - std::visit(overloaded{ - [&](std::monostate&){ - throw std::runtime_error("Transport info not set"); - }, - [&](cudaIpcMemHandle_t handle){ - std::copy_n(reinterpret_cast(&handle), sizeof(handle), std::back_inserter(result)); - }, - [&](mscclppIbMr* mr){ - std::copy_n(reinterpret_cast(&mr->info), sizeof(mr->info), std::back_inserter(result)); - }, - [&](mscclppIbMrInfo info){ - std::copy_n(reinterpret_cast(&info), sizeof(info), std::back_inserter(result)); - } - }, entry.data); + if (entry.transport == TransportCudaIpc) { + std::copy_n(reinterpret_cast(&entry.cudaIpcHandle), sizeof(entry.cudaIpcHandle), std::back_inserter(result)); + } else if (entry.transport & TransportAllIB) { + std::copy_n(reinterpret_cast(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result)); + } else { + 
throw std::runtime_error("Unknown transport"); + } } return result; } -static RegisteredMemory RegisteredMemory::deserialize(const std::vector& data) { +RegisteredMemory RegisteredMemory::deserialize(const std::vector& data) { return RegisteredMemory(std::make_shared(data)); } -RegisteredMemory::Impl::Impl(const std::vector& data) { - auto it = data.begin(); +RegisteredMemory::Impl::Impl(const std::vector& serialization) { + auto it = serialization.begin(); std::copy_n(it, sizeof(this->size), reinterpret_cast(&this->size)); it += sizeof(this->size); std::copy_n(it, sizeof(this->rank), reinterpret_cast(&this->rank)); @@ -104,24 +100,25 @@ RegisteredMemory::Impl::Impl(const std::vector& data) { cudaIpcMemHandle_t handle; std::copy_n(it, sizeof(handle), reinterpret_cast(&handle)); it += sizeof(handle); - transportInfo.data = handle; + transportInfo.cudaIpcHandle = handle; } else if (transportInfo.transport & TransportAllIB) { mscclppIbMrInfo info; std::copy_n(it, sizeof(info), reinterpret_cast(&info)); it += sizeof(info); - transportInfo.data = info; + transportInfo.ibMrInfo = info; + transportInfo.ibLocal = false; } else { throw std::runtime_error("Unknown transport"); } this->transportInfos.push_back(transportInfo); } - if (it != data.end()) { + if (it != serialization.end()) { throw std::runtime_error("Deserialization failed"); } if (transports & TransportCudaIpc) { - auto cudaIpcHandle = getTransportInfo(TransportCudaIpc); - CUDATHROW(cudaIpcOpenMemHandle(&data, cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); + auto entry = getTransportInfo(TransportCudaIpc); + CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); } } From 5443ed1ec22cedb7db7e0c55e9d80e555576ad06 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Wed, 26 Apr 2023 18:07:17 +0000 Subject: [PATCH 06/54] ConnectionSetup stuff --- src/communicator.cc | 12 +++++++++--- src/connection.cc | 14 +++++++++++--- src/include/communicator.hpp | 4 +++- 
src/include/connection.hpp | 18 +++++++++++++++--- 4 files changed, 38 insertions(+), 10 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index 316801de5..9ce5b7791 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -76,20 +76,26 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() { } MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, TransportFlags transport) { - std::shared_ptr conn; + std::shared_ptr conn; if (transport | TransportCudaIpc) { auto cudaIpcConn = std::make_shared(); conn = cudaIpcConn; } else if (transport | TransportAllIB) { - auto ibConn = std::make_shared(transport, *pimpl); + auto ibConn = std::make_shared(remoteRank, tag, transport, *pimpl); conn = ibConn; } else { throw std::runtime_error("Unsupported transport"); } + pimpl->connections.push_back(conn); } MSCCLPP_API_CPP void Communicator::connectionSetup() { - mscclppConnectionSetup(pimpl->comm); + for (auto& conn : pimpl->connections) { + conn->startSetup(*this); + } + for (auto& conn : pimpl->connections) { + conn->endSetup(*this); + } } MSCCLPP_API_CPP int Communicator::rank() { diff --git a/src/connection.cc b/src/connection.cc index 3e053cb32..24482c7b8 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -33,8 +33,8 @@ void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, Register validateTransport(dst, remoteTransport()); validateTransport(src, transport()); - auto dstPtr = dst.data(); - auto srcPtr = src.data(); + char* dstPtr = (char*)dst.data(); + char* srcPtr = (char*)src.data(); CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream)); // npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size); @@ -47,7 +47,7 @@ void CudaIpcConnection::flush() { // IBConnection -IBConnection::IBConnection(TransportFlags transport, Communicator::Impl& commImpl) : transport_(transport), remoteTransport_(TransportNone) { 
+IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank(remoteRank), tag(tag), transport_(transport), remoteTransport_(TransportNone) { MSCCLPPTHROW(mscclppIbContextCreateQp(commImpl.getIbContext(transport), &qp)); } @@ -116,4 +116,12 @@ void IBConnection::flush() { // npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT); } +void startSetup(Communicator& comm) { + // TODO: use bootstrapper from comm to send over QP info +} + +void endSetup(Communicator& comm) { + // TODO: use bootstrapper from comm to receive QP info and do the rtr/rts calls +} + } // namespace mscclpp diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 8eb0e2026..879501c0c 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -10,9 +10,11 @@ namespace mscclpp { +class ConnectionBase; + struct Communicator::Impl { mscclppComm_t comm; - std::vector> connections; + std::vector> connections; std::unordered_map ibContexts; Impl(); diff --git a/src/include/connection.hpp b/src/include/connection.hpp index 94d727e77..ac1dd6a17 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -10,7 +10,13 @@ namespace mscclpp { // TODO: Add functionality to these classes for Communicator to do connectionSetup -class CudaIpcConnection : public Connection { +class ConnectionBase : public Connection { +public: + virtual void startSetup(Communicator&) {}; + virtual void endSetup(Communicator&) {}; +}; + +class CudaIpcConnection : public ConnectionBase { cudaStream_t stream; public: @@ -27,13 +33,15 @@ class CudaIpcConnection : public Connection { void flush() override; }; -class IBConnection : public Connection { +class IBConnection : public ConnectionBase { + int remoteRank; + int tag; TransportFlags transport_; TransportFlags remoteTransport_; mscclppIbQp* qp; public: - IBConnection(TransportFlags transport, Communicator::Impl& commImpl); + IBConnection(int remoteRank, int tag, 
TransportFlags transport, Communicator::Impl& commImpl); ~IBConnection(); @@ -44,6 +52,10 @@ class IBConnection : public Connection { void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; void flush() override; + + void startSetup(Communicator& comm) override; + + void endSetup(Communicator& comm) override; }; } // namespace mscclpp From 9c6e68525353ef0d4ea450b816fefd0c34a0f45b Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Wed, 26 Apr 2023 23:46:22 +0000 Subject: [PATCH 07/54] connectionSetup() for IBConnection --- src/connection.cc | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/connection.cc b/src/connection.cc index 24482c7b8..b682903a4 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -116,12 +116,19 @@ void IBConnection::flush() { // npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT); } -void startSetup(Communicator& comm) { - // TODO: use bootstrapper from comm to send over QP info +void IBConnection::startSetup(Communicator& comm) { + comm.bootstrap().send(&qp->info, sizeof(qp->info), remoteRank, tag); } -void endSetup(Communicator& comm) { - // TODO: use bootstrapper from comm to receive QP info and do the rtr/rts calls +void IBConnection::endSetup(Communicator& comm) { + mscclppIbQpInfo qpInfo; + comm.bootstrap().recv(&qpInfo, sizeof(qpInfo), remoteRank, tag); + if (qp->rtr(&qpInfo) != 0) { + throw std::runtime_error("Failed to transition QP to RTR"); + } + if (qp->rts() != 0) { + throw std::runtime_error("Failed to transition QP to RTS"); + } } } // namespace mscclpp From 7c87ca300526663ff0121b98a29f3415c9aff87d Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Thu, 27 Apr 2023 00:01:38 +0000 Subject: [PATCH 08/54] Missing functions and TODOs --- TODO.md | 7 +++++ src/communicator.cc | 77 +++++++++++++++++++++++++++++++++++++++++++++ src/connection.cc | 6 ++++ 3 files changed, 90 insertions(+) create mode 100644 TODO.md diff --git 
a/TODO.md b/TODO.md new file mode 100644 index 000000000..5338c10ef --- /dev/null +++ b/TODO.md @@ -0,0 +1,7 @@ +# Core API extraction + +- Add a test for host side Communicator/RegisteredMemory/Connection use. +- Implement a standalone "epoch" synchronization construct that can be used as a component in custom proxies. epoch.hpp/cc has the beginnings of this. +- Reimplement the "standard" proxy service + DeviceConnection on top of the new Communicator/RegisteredMemory/Connection core API. Remants of the old code is in channel.hpp, basic_proxy_handler.hpp/cc and host_connection.hpp/cc. +- Change the new IBConnection and Communicator to use the new C++ IbCtx and IbQp classes. +- Implement IbQp::~IbQp() \ No newline at end of file diff --git a/src/communicator.cc b/src/communicator.cc index 9ce5b7791..ce26d64a5 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -110,4 +110,81 @@ MSCCLPP_API_CPP int Communicator::size() { return result; } +// TODO: move these elsewhere + +int getIBDeviceCount() { + int num; + struct ibv_device** devices = ibv_get_device_list(&num); + return num; +} + +std::string getIBDeviceName(TransportFlags ibTransport) { + int num; + struct ibv_device** devices = ibv_get_device_list(&num); + int ibTransportIndex; + switch (ibTransport) { // TODO: get rid of this ugly switch + case TransportIB0: + ibTransportIndex = 0; + break; + case TransportIB1: + ibTransportIndex = 1; + break; + case TransportIB2: + ibTransportIndex = 2; + break; + case TransportIB3: + ibTransportIndex = 3; + break; + case TransportIB4: + ibTransportIndex = 4; + break; + case TransportIB5: + ibTransportIndex = 5; + break; + case TransportIB6: + ibTransportIndex = 6; + break; + case TransportIB7: + ibTransportIndex = 7; + break; + default: + throw std::runtime_error("Not an IB transport"); + } + if (ibTransportIndex >= num) { + throw std::runtime_error("IB transport out of range"); + } + return devices[ibTransportIndex]->name; +} + +TransportFlags 
getIBTransportByDeviceName(const std::string& ibDeviceName) { + int num; + struct ibv_device** devices = ibv_get_device_list(&num); + for (int i = 0; i < num; ++i) { + if (ibDeviceName == devices[i]->name) { + switch (i) { // TODO: get rid of this ugly switch + case 0: + return TransportIB0; + case 1: + return TransportIB1; + case 2: + return TransportIB2; + case 3: + return TransportIB3; + case 4: + return TransportIB4; + case 5: + return TransportIB5; + case 6: + return TransportIB6; + case 7: + return TransportIB7; + default: + throw std::runtime_error("IB device index out of range"); + } + } + } + throw std::runtime_error("IB device not found"); +} + + } // namespace mscclpp \ No newline at end of file diff --git a/src/connection.cc b/src/connection.cc index b682903a4..8d1b5e113 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -11,6 +11,12 @@ void validateTransport(RegisteredMemory mem, TransportFlags transport) { } } +// Connection + +std::shared_ptr Connection::getRegisteredMemoryImpl(RegisteredMemory& mem) { + return mem.pimpl; +} + // CudaIpcConnection CudaIpcConnection::CudaIpcConnection() { From d096874d578d1f8c5e598002c3bd51b7b5972dc7 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Thu, 27 Apr 2023 00:22:29 +0000 Subject: [PATCH 09/54] TODO updates --- TODO.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TODO.md b/TODO.md index 5338c10ef..63cb4eb72 100644 --- a/TODO.md +++ b/TODO.md @@ -2,6 +2,6 @@ - Add a test for host side Communicator/RegisteredMemory/Connection use. - Implement a standalone "epoch" synchronization construct that can be used as a component in custom proxies. epoch.hpp/cc has the beginnings of this. -- Reimplement the "standard" proxy service + DeviceConnection on top of the new Communicator/RegisteredMemory/Connection core API. Remants of the old code is in channel.hpp, basic_proxy_handler.hpp/cc and host_connection.hpp/cc. 
+- Reimplement the "standard" proxy service + DeviceConnection on top of the new Communicator/RegisteredMemory/Connection core API. Remants of the old code is in channel.hpp, basic_proxy_handler.hpp/cc and host_connection.hpp/cc. Probably need a manager class to wrap all of this. - Change the new IBConnection and Communicator to use the new C++ IbCtx and IbQp classes. - Implement IbQp::~IbQp() \ No newline at end of file From 0e9f6fadc73b36f0117ccc5845f72c0884d0ce52 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Thu, 27 Apr 2023 00:26:00 +0000 Subject: [PATCH 10/54] TODOs --- TODO.md | 3 ++- src/registered_memory.cc | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/TODO.md b/TODO.md index 63cb4eb72..677b46cfd 100644 --- a/TODO.md +++ b/TODO.md @@ -4,4 +4,5 @@ - Implement a standalone "epoch" synchronization construct that can be used as a component in custom proxies. epoch.hpp/cc has the beginnings of this. - Reimplement the "standard" proxy service + DeviceConnection on top of the new Communicator/RegisteredMemory/Connection core API. Remants of the old code is in channel.hpp, basic_proxy_handler.hpp/cc and host_connection.hpp/cc. Probably need a manager class to wrap all of this. - Change the new IBConnection and Communicator to use the new C++ IbCtx and IbQp classes. -- Implement IbQp::~IbQp() \ No newline at end of file +- Implement IbQp::~IbQp() +- Fix RegisteredMemory::Impl::Impl to get the IPC handle from the base pointer, not the derived pointer. 
\ No newline at end of file diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 7a5a0725d..d9476e4f9 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -9,6 +9,7 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t TransportInfo transportInfo; transportInfo.transport = TransportCudaIpc; cudaIpcMemHandle_t handle; + // TODO: translate data to a base pointer CUDATHROW(cudaIpcGetMemHandle(&handle, data)); transportInfo.cudaIpcHandle = handle; this->transportInfos.push_back(transportInfo); From 47d4606f130deb6311a8d23f482f59aa630957d6 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Thu, 27 Apr 2023 00:33:24 +0000 Subject: [PATCH 11/54] Add registerMemory --- src/communicator.cc | 7 ++++++- src/include/mscclpp.hpp | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/communicator.cc b/src/communicator.cc index ce26d64a5..c34dbb316 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -8,6 +8,7 @@ #include "checks.hpp" #include "debug.h" #include "connection.hpp" +#include "registered_memory.hpp" namespace mscclpp { @@ -75,6 +76,10 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() { mscclppBootstrapBarrier(pimpl->comm); } +RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) { + return RegisteredMemory(std::make_shared(ptr, size, pimpl->comm->rank, transports, *pimpl)); +} + MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, TransportFlags transport) { std::shared_ptr conn; if (transport | TransportCudaIpc) { @@ -114,7 +119,7 @@ MSCCLPP_API_CPP int Communicator::size() { int getIBDeviceCount() { int num; - struct ibv_device** devices = ibv_get_device_list(&num); + ibv_get_device_list(&num); return num; } diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index f2d8667e1..bd4bc067c 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -102,6 +102,7 @@ class 
RegisteredMemory { static RegisteredMemory deserialize(const std::vector& data); friend class Connection; + friend class Communicator; }; class Connection { From 08e80f1754527fe9f72026d032c1b08301587a8d Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 27 Apr 2023 04:01:46 +0000 Subject: [PATCH 12/54] IB: completely replaced with C++ interfaces --- src/communicator.cc | 87 +---- src/connection.cc | 34 +- src/ib.cc | 574 +++++++++++++----------------- src/include/comm.h | 7 +- src/include/communicator.hpp | 6 +- src/include/connection.hpp | 4 +- src/include/ib.h | 69 ---- src/include/ib.hpp | 53 ++- src/include/mscclpp.h | 17 +- src/include/proxy.h | 2 +- src/include/registered_memory.hpp | 6 +- src/init.cc | 79 ++-- src/proxy.cc | 2 +- src/registered_memory.cc | 5 +- tests/unittests/ib_test.cc | 64 ++-- 15 files changed, 390 insertions(+), 619 deletions(-) delete mode 100644 src/include/ib.h diff --git a/src/communicator.cc b/src/communicator.cc index c34dbb316..6c501d70a 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -16,7 +16,7 @@ Communicator::Impl::Impl() : comm(nullptr) {} Communicator::Impl::~Impl() { for (auto& entry : ibContexts) { - mscclppIbContextDestroy(entry.second); + delete entry.second; } ibContexts.clear(); if (comm) { @@ -24,13 +24,12 @@ Communicator::Impl::~Impl() { } } -mscclppIbContext* Communicator::Impl::getIbContext(TransportFlags ibTransport) { +IbCtx* Communicator::Impl::getIbContext(TransportFlags ibTransport) { // Find IB context or create it auto it = ibContexts.find(ibTransport); if (it == ibContexts.end()) { auto ibDev = getIBDeviceName(ibTransport); - mscclppIbContext* ibCtx; - MSCCLPPTHROW(mscclppIbContextCreate(&ibCtx, ibDev.c_str())); + IbCtx* ibCtx = new IbCtx(ibDev); ibContexts[ibTransport] = ibCtx; return ibCtx; } else { @@ -92,6 +91,7 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank throw std::runtime_error("Unsupported transport"); } pimpl->connections.push_back(conn); + return 
conn; } MSCCLPP_API_CPP void Communicator::connectionSetup() { @@ -115,81 +115,4 @@ MSCCLPP_API_CPP int Communicator::size() { return result; } -// TODO: move these elsewhere - -int getIBDeviceCount() { - int num; - ibv_get_device_list(&num); - return num; -} - -std::string getIBDeviceName(TransportFlags ibTransport) { - int num; - struct ibv_device** devices = ibv_get_device_list(&num); - int ibTransportIndex; - switch (ibTransport) { // TODO: get rid of this ugly switch - case TransportIB0: - ibTransportIndex = 0; - break; - case TransportIB1: - ibTransportIndex = 1; - break; - case TransportIB2: - ibTransportIndex = 2; - break; - case TransportIB3: - ibTransportIndex = 3; - break; - case TransportIB4: - ibTransportIndex = 4; - break; - case TransportIB5: - ibTransportIndex = 5; - break; - case TransportIB6: - ibTransportIndex = 6; - break; - case TransportIB7: - ibTransportIndex = 7; - break; - default: - throw std::runtime_error("Not an IB transport"); - } - if (ibTransportIndex >= num) { - throw std::runtime_error("IB transport out of range"); - } - return devices[ibTransportIndex]->name; -} - -TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName) { - int num; - struct ibv_device** devices = ibv_get_device_list(&num); - for (int i = 0; i < num; ++i) { - if (ibDeviceName == devices[i]->name) { - switch (i) { // TODO: get rid of this ugly switch - case 0: - return TransportIB0; - case 1: - return TransportIB1; - case 2: - return TransportIB2; - case 3: - return TransportIB3; - case 4: - return TransportIB4; - case 5: - return TransportIB5; - case 6: - return TransportIB6; - case 7: - return TransportIB7; - default: - throw std::runtime_error("IB device index out of range"); - } - } - } - throw std::runtime_error("IB device not found"); -} - - -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/connection.cc b/src/connection.cc index 8d1b5e113..1e21694cd 100644 --- a/src/connection.cc +++ 
b/src/connection.cc @@ -2,6 +2,7 @@ #include "checks.hpp" #include "registered_memory.hpp" #include "npkit/npkit.h" +#include "infiniband/verbs.h" namespace mscclpp { @@ -54,7 +55,7 @@ void CudaIpcConnection::flush() { // IBConnection IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank(remoteRank), tag(tag), transport_(transport), remoteTransport_(TransportNone) { - MSCCLPPTHROW(mscclppIbContextCreateQp(commImpl.getIbContext(transport), &qp)); + qp = commImpl.getIbContext(transport)->createQp(); } IBConnection::~IBConnection() { @@ -85,13 +86,8 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem auto dstMrInfo = dstTransportInfo.ibMrInfo; auto srcMr = srcTransportInfo.ibMr; - qp->stageSend(srcMr, &dstMrInfo, (uint32_t)size, - /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false); - int ret = qp->postSend(); - if (ret != 0) { - // Return value is errno. - WARN("data postSend failed: errno %d", ret); - } + qp->stageSend(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false); + qp->postSend(); // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size); } @@ -104,15 +100,11 @@ void IBConnection::flush() { continue; } for (int i = 0; i < wcNum; ++i) { - struct ibv_wc* wc = &qp->wcs[i]; + const struct ibv_wc* wc = reinterpret_cast(qp->getWc(i)); if (wc->status != IBV_WC_SUCCESS) { WARN("wc status %d", wc->status); continue; } - if (wc->qp_num != qp->qp->qp_num) { - WARN("got wc of unknown qp_num %d", wc->qp_num); - continue; - } if (wc->opcode == IBV_WC_RDMA_WRITE) { isWaiting = false; break; @@ -123,18 +115,16 @@ void IBConnection::flush() { } void IBConnection::startSetup(Communicator& comm) { - comm.bootstrap().send(&qp->info, sizeof(qp->info), remoteRank, tag); + // TODO(chhwang): temporarily disabled to compile + // comm.bootstrap().send(&qp->getInfo(), 
sizeof(qp->getInfo()), remoteRank, tag); } void IBConnection::endSetup(Communicator& comm) { - mscclppIbQpInfo qpInfo; - comm.bootstrap().recv(&qpInfo, sizeof(qpInfo), remoteRank, tag); - if (qp->rtr(&qpInfo) != 0) { - throw std::runtime_error("Failed to transition QP to RTR"); - } - if (qp->rts() != 0) { - throw std::runtime_error("Failed to transition QP to RTS"); - } + IbQpInfo qpInfo; + // TODO(chhwang): temporarily disabled to compile + // comm.bootstrap().recv(&qpInfo, sizeof(qpInfo), remoteRank, tag); + qp->rtr(qpInfo); + qp->rts(); } } // namespace mscclpp diff --git a/src/ib.cc b/src/ib.cc index 4a0947619..4dc0285ba 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -4,282 +4,175 @@ #include #include #include -#include +#include "mscclpp.hpp" #include "alloc.h" #include "comm.h" #include "debug.h" -#include "ib.h" #include "ib.hpp" #include "checks.hpp" +#include +#include -mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName) -{ - struct mscclppIbContext* _ctx; - MSCCLPPCHECK(mscclppCalloc(&_ctx, 1)); - - std::vector ports; +namespace mscclpp { - int num; - struct ibv_device** devices = ibv_get_device_list(&num); - for (int i = 0; i < num; ++i) { - if (strncmp(devices[i]->name, ibDevName, IBV_SYSFS_NAME_MAX) == 0) { - _ctx->ctx = ibv_open_device(devices[i]); - break; - } +IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff) +{ + if (size == 0) { + throw std::runtime_error("invalid size: " + std::to_string(size)); } - ibv_free_device_list(devices); - if (_ctx->ctx == nullptr) { - WARN("ibv_open_device failed (errno %d, device name %s)", errno, ibDevName); - goto fail; + static __thread uintptr_t pageSize = 0; + if (pageSize == 0) { + pageSize = sysconf(_SC_PAGESIZE); } - - // Check available ports - struct ibv_device_attr devAttr; - if (ibv_query_device(_ctx->ctx, &devAttr) != 0) { - WARN("ibv_query_device failed (errno %d, device name %s)", errno, ibDevName); - goto fail; + uintptr_t addr = reinterpret_cast(buff) & 
-pageSize; + std::size_t pages = (size + (reinterpret_cast(buff) - addr) + pageSize - 1) / pageSize; + struct ibv_pd* _pd = reinterpret_cast(pd); + struct ibv_mr* _mr = ibv_reg_mr(_pd, reinterpret_cast(addr), pages * pageSize, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING); + if (_mr == nullptr) { + std::stringstream err; + err << "ibv_reg_mr failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); } + this->mr = _mr; + this->size = pages * pageSize; +} - for (uint8_t i = 1; i <= devAttr.phys_port_cnt; ++i) { - struct ibv_port_attr portAttr; - if (ibv_query_port(_ctx->ctx, i, &portAttr) != 0) { - WARN("ibv_query_port failed (errno %d, port %d)", errno, i); - goto fail; - } - if (portAttr.state != IBV_PORT_ACTIVE) { - continue; - } - if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND && portAttr.link_layer != IBV_LINK_LAYER_ETHERNET) { - continue; - } - ports.push_back((int)i); - } - if (ports.size() == 0) { - WARN("no active IB port found"); - goto fail; - } - MSCCLPPCHECK(mscclppCalloc(&_ctx->ports, ports.size())); - _ctx->nPorts = (int)ports.size(); - for (int i = 0; i < _ctx->nPorts; ++i) { - _ctx->ports[i] = ports[i]; - } +IbMr::~IbMr() +{ + ibv_dereg_mr(reinterpret_cast(this->mr)); +} - _ctx->pd = ibv_alloc_pd(_ctx->ctx); - if (_ctx->pd == NULL) { - WARN("ibv_alloc_pd failed (errno %d)", errno); - goto fail; - } +IbMrInfo IbMr::getInfo() const +{ + IbMrInfo info; + info.addr = reinterpret_cast(this->buff); + info.rkey = reinterpret_cast(this->mr)->rkey; + return info; +} - *ctx = _ctx; - return mscclppSuccess; -fail: - *ctx = NULL; - if (_ctx->ports != NULL) { - free(_ctx->ports); - } - free(_ctx); - return mscclppInternalError; +const void* IbMr::getBuff() const +{ + return this->buff; } -mscclppResult_t mscclppIbContextDestroy(struct mscclppIbContext* ctx) +uint32_t IbMr::getLkey() const { - for (int i = 0; i < ctx->nMrs; ++i) { - if (ctx->mrs[i].mr) { - ibv_dereg_mr(ctx->mrs[i].mr); 
- } - } - for (int i = 0; i < ctx->nQps; ++i) { - if (ctx->qps[i].qp) { - ibv_destroy_qp(ctx->qps[i].qp); - } - ibv_destroy_cq(ctx->qps[i].cq); - free(ctx->qps[i].wcs); - free(ctx->qps[i].sges); - free(ctx->qps[i].wrs); - } - if (ctx->pd != NULL) { - ibv_dealloc_pd(ctx->pd); - } - if (ctx->ctx != NULL) { - ibv_close_device(ctx->ctx); - } - free(ctx->mrs); - free(ctx->qps); - free(ctx->ports); - free(ctx); - return mscclppSuccess; + return reinterpret_cast(this->mr)->lkey; } -mscclppResult_t mscclppIbContextCreateQp(struct mscclppIbContext* ctx, struct mscclppIbQp** ibQp, int port /*=-1*/) +IbQp::IbQp(void* ctx, void* pd, int port) { - if (port < 0) { - port = ctx->ports[0]; - } else { - bool found = false; - for (int i = 0; i < ctx->nPorts; ++i) { - if (ctx->ports[i] == port) { - found = true; - break; - } - } - if (!found) { - WARN("invalid IB port: %d", port); - return mscclppInternalError; - } - } + struct ibv_context* _ctx = reinterpret_cast(ctx); + struct ibv_pd* _pd = reinterpret_cast(pd); - struct ibv_cq* cq = ibv_create_cq(ctx->ctx, MSCCLPP_IB_CQ_SIZE, NULL, NULL, 0); - if (cq == NULL) { - WARN("ibv_create_cq failed (errno %d)", errno); - return mscclppInternalError; + this->cq = ibv_create_cq(_ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0); + if (this->cq == nullptr) { + std::stringstream err; + err << "ibv_create_cq failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); } - struct ibv_qp_init_attr qp_init_attr; - std::memset(&qp_init_attr, 0, sizeof(struct ibv_qp_init_attr)); - qp_init_attr.sq_sig_all = 0; - qp_init_attr.send_cq = cq; - qp_init_attr.recv_cq = cq; - qp_init_attr.qp_type = IBV_QPT_RC; - qp_init_attr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; - qp_init_attr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; - qp_init_attr.cap.max_send_sge = 1; - qp_init_attr.cap.max_recv_sge = 1; - qp_init_attr.cap.max_inline_data = 0; - struct ibv_qp* qp = ibv_create_qp(ctx->pd, &qp_init_attr); - if (qp == nullptr) 
{ - WARN("ibv_create_qp failed (errno %d)", errno); - return mscclppInternalError; - } - struct ibv_port_attr port_attr; - if (ibv_query_port(ctx->ctx, port, &port_attr) != 0) { - WARN("ibv_query_port failed (errno %d, port %d)", errno, port); - return mscclppInternalError; - } + struct ibv_qp_init_attr qpInitAttr; + std::memset(&qpInitAttr, 0, sizeof(qpInitAttr)); + qpInitAttr.sq_sig_all = 0; + qpInitAttr.send_cq = reinterpret_cast(this->cq); + qpInitAttr.recv_cq = reinterpret_cast(this->cq); + qpInitAttr.qp_type = IBV_QPT_RC; + qpInitAttr.cap.max_send_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; + qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; + qpInitAttr.cap.max_send_sge = 1; + qpInitAttr.cap.max_recv_sge = 1; + qpInitAttr.cap.max_inline_data = 0; - // Register QP to this ctx - qp->context = ctx->ctx; - if (qp->context == NULL) { - WARN("IB context is NULL"); - return mscclppInternalError; - } - ctx->nQps++; - if (ctx->qps == NULL) { - MSCCLPPCHECK(mscclppCalloc(&ctx->qps, MAXCONNECTIONS)); - ctx->maxQps = MAXCONNECTIONS; + struct ibv_qp* _qp = ibv_create_qp(_pd, &qpInitAttr); + if (_qp == nullptr) { + std::stringstream err; + err << "ibv_create_qp failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); } - if (ctx->maxQps < ctx->nQps) { - WARN("too many QPs"); - return mscclppInternalError; + + struct ibv_port_attr portAttr; + if (ibv_query_port(_ctx, port, &portAttr) != 0) { + std::stringstream err; + err << "ibv_query_port failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); } - struct mscclppIbQp* _ibQp = &ctx->qps[ctx->nQps - 1]; - _ibQp->qp = qp; - _ibQp->info.lid = port_attr.lid; - _ibQp->info.port = port; - _ibQp->info.linkLayer = port_attr.link_layer; - _ibQp->info.qpn = qp->qp_num; - _ibQp->info.mtu = port_attr.active_mtu; - if (port_attr.link_layer != IBV_LINK_LAYER_INFINIBAND) { + this->info.lid = portAttr.lid; + this->info.port = port; + this->info.linkLayer = portAttr.link_layer; + 
this->info.qpn = _qp->qp_num; + this->info.mtu = portAttr.active_mtu; + if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND) { union ibv_gid gid; - if (ibv_query_gid(ctx->ctx, port, 0, &gid) != 0) { - WARN("ibv_query_gid failed (errno %d)", errno); - return mscclppInternalError; + if (ibv_query_gid(_ctx, port, 0, &gid) != 0) { + std::stringstream err; + err << "ibv_query_gid failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); } - _ibQp->info.spn = gid.global.subnet_prefix; + this->info.spn = gid.global.subnet_prefix; } - struct ibv_qp_attr qp_attr; - std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); - qp_attr.qp_state = IBV_QPS_INIT; - qp_attr.pkey_index = 0; - qp_attr.port_num = port; - qp_attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; - if (ibv_modify_qp(qp, &qp_attr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { - WARN("ibv_modify_qp failed (errno %d)", errno); - return mscclppInternalError; + struct ibv_qp_attr qpAttr; + memset(&qpAttr, 0, sizeof(qpAttr)); + qpAttr.qp_state = IBV_QPS_INIT; + qpAttr.pkey_index = 0; + qpAttr.port_num = port; + qpAttr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; + if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + std::stringstream err; + err << "ibv_modify_qp failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); } - - MSCCLPPCHECK(mscclppCalloc(&_ibQp->wrs, MSCCLPP_IB_MAX_SENDS)); - MSCCLPPCHECK(mscclppCalloc(&_ibQp->sges, MSCCLPP_IB_MAX_SENDS)); - MSCCLPPCHECK(mscclppCalloc(&_ibQp->wcs, MSCCLPP_IB_CQ_POLL_NUM)); - _ibQp->cq = cq; - - *ibQp = _ibQp; - - return mscclppSuccess; + this->qp = _qp; + MSCCLPPTHROW(mscclppCalloc(reinterpret_cast(&this->wrs), MSCCLPP_IB_MAX_SENDS)); + MSCCLPPTHROW(mscclppCalloc(reinterpret_cast(&this->sges), MSCCLPP_IB_MAX_SENDS)); + MSCCLPPTHROW(mscclppCalloc(reinterpret_cast(&this->wcs), MSCCLPP_IB_CQ_POLL_NUM)); } 
-mscclppResult_t mscclppIbContextRegisterMr(struct mscclppIbContext* ctx, void* buff, size_t size, - struct mscclppIbMr** ibMr) +IbQp::~IbQp() { - if (size == 0) { - WARN("invalid size: %zu", size); - return mscclppInvalidArgument; - } - static __thread uintptr_t pageSize = 0; - if (pageSize == 0) { - pageSize = sysconf(_SC_PAGESIZE); - } - uintptr_t addr = reinterpret_cast(buff) & -pageSize; - size_t pages = (size + (reinterpret_cast(buff) - addr) + pageSize - 1) / pageSize; - struct ibv_mr* mr = - ibv_reg_mr(ctx->pd, reinterpret_cast(addr), pages * pageSize, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING); - if (mr == nullptr) { - WARN("ibv_reg_mr failed (errno %d)", errno); - return mscclppInternalError; - } - ctx->nMrs++; - if (ctx->mrs == NULL) { - MSCCLPPCHECK(mscclppCalloc(&ctx->mrs, MAXCONNECTIONS)); - ctx->maxMrs = MAXCONNECTIONS; - } - if (ctx->maxMrs < ctx->nMrs) { - WARN("too many MRs"); - return mscclppInternalError; - } - struct mscclppIbMr* _ibMr = &ctx->mrs[ctx->nMrs - 1]; - _ibMr->mr = mr; - _ibMr->buff = buff; - _ibMr->info.addr = (uint64_t)buff; - _ibMr->info.rkey = mr->rkey; - *ibMr = _ibMr; - return mscclppSuccess; + ibv_destroy_qp(reinterpret_cast(this->qp)); + ibv_destroy_cq(reinterpret_cast(this->cq)); + std::free(this->wrs); + std::free(this->sges); + std::free(this->wcs); } -////////////////////////////////////////////////////////////////////////////// - -int mscclppIbQp::rtr(const mscclppIbQpInfo* info) +void IbQp::rtr(const IbQpInfo& info) { struct ibv_qp_attr qp_attr; std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); qp_attr.qp_state = IBV_QPS_RTR; - qp_attr.path_mtu = info->mtu; - qp_attr.dest_qp_num = info->qpn; + qp_attr.path_mtu = static_cast(info.mtu); + qp_attr.dest_qp_num = info.qpn; qp_attr.rq_psn = 0; qp_attr.max_dest_rd_atomic = 1; qp_attr.min_rnr_timer = 0x12; - if (info->linkLayer == IBV_LINK_LAYER_ETHERNET) { + if (info.linkLayer == IBV_LINK_LAYER_ETHERNET) { 
qp_attr.ah_attr.is_global = 1; - qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info->spn; - qp_attr.ah_attr.grh.dgid.global.interface_id = info->lid; + qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info.spn; + qp_attr.ah_attr.grh.dgid.global.interface_id = info.lid; qp_attr.ah_attr.grh.flow_label = 0; qp_attr.ah_attr.grh.sgid_index = 0; qp_attr.ah_attr.grh.hop_limit = 255; qp_attr.ah_attr.grh.traffic_class = 0; } else { qp_attr.ah_attr.is_global = 0; - qp_attr.ah_attr.dlid = info->lid; + qp_attr.ah_attr.dlid = info.lid; } qp_attr.ah_attr.sl = 0; qp_attr.ah_attr.src_path_bits = 0; - qp_attr.ah_attr.port_num = info->port; - return ibv_modify_qp(this->qp, &qp_attr, - IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | - IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); + qp_attr.ah_attr.port_num = info.port; + int ret = ibv_modify_qp(reinterpret_cast(this->qp), &qp_attr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); + if (ret != 0) { + std::stringstream err; + err << "ibv_modify_qp failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } } -int mscclppIbQp::rts() +void IbQp::rts() { struct ibv_qp_attr qp_attr; std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); @@ -289,153 +182,103 @@ int mscclppIbQp::rts() qp_attr.rnr_retry = 7; qp_attr.sq_psn = 0; qp_attr.max_rd_atomic = 1; - return ibv_modify_qp(this->qp, &qp_attr, - IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | - IBV_QP_MAX_QP_RD_ATOMIC); + int ret = ibv_modify_qp(reinterpret_cast(this->qp), &qp_attr, IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC); + if (ret != 0) { + std::stringstream err; + err << "ibv_modify_qp failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } } -int mscclppIbQp::stageSend(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, 
uint32_t size, uint64_t wrId, - uint64_t srcOffset, uint64_t dstOffset, bool signaled) +int IbQp::stageSend(const IbMr *mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled) { if (this->wrn >= MSCCLPP_IB_MAX_SENDS) { return -1; } int wrn = this->wrn; - struct ibv_send_wr* wr_ = &this->wrs[wrn]; - struct ibv_sge* sge_ = &this->sges[wrn]; - // std::memset(wr_, 0, sizeof(struct ibv_send_wr)); - // std::memset(sge_, 0, sizeof(struct ibv_sge)); + struct ibv_send_wr* wrs_ = reinterpret_cast(this->wrs); + struct ibv_sge* sges_ = reinterpret_cast(this->sges); + + struct ibv_send_wr* wr_ = &wrs_[wrn]; + struct ibv_sge* sge_ = &sges_[wrn]; wr_->wr_id = wrId; wr_->sg_list = sge_; wr_->num_sge = 1; wr_->opcode = IBV_WR_RDMA_WRITE; wr_->send_flags = signaled ? IBV_SEND_SIGNALED : 0; - wr_->wr.rdma.remote_addr = (uint64_t)(info->addr) + dstOffset; - wr_->wr.rdma.rkey = info->rkey; + wr_->wr.rdma.remote_addr = (uint64_t)(info.addr) + dstOffset; + wr_->wr.rdma.rkey = info.rkey; wr_->next = nullptr; - sge_->addr = (uint64_t)(ibMr->buff) + srcOffset; + sge_->addr = (uint64_t)(mr->getBuff()) + srcOffset; sge_->length = size; - sge_->lkey = ibMr->mr->lkey; + sge_->lkey = mr->getLkey(); if (wrn > 0) { - this->wrs[wrn - 1].next = wr_; + wrs_[wrn - 1].next = wr_; } this->wrn++; return this->wrn; } -int mscclppIbQp::stageSendWithImm(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId, - uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData) +int IbQp::stageSendWithImm(const IbMr *mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData) { - int wrn = this->stageSend(ibMr, info, size, wrId, srcOffset, dstOffset, signaled); - this->wrs[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - this->wrs[wrn - 1].imm_data = immData; + int wrn = this->stageSend(mr, info, size, wrId, srcOffset, dstOffset, signaled); + 
struct ibv_send_wr* wrs_ = reinterpret_cast(this->wrs); + wrs_[wrn - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wrs_[wrn - 1].imm_data = immData; return wrn; } -int mscclppIbQp::postSend() +void IbQp::postSend() { if (this->wrn == 0) { - return 0; + return; } - struct ibv_send_wr* bad_wr; - int ret = ibv_post_send(this->qp, this->wrs, &bad_wr); + int ret = ibv_post_send(reinterpret_cast(this->qp), reinterpret_cast(this->wrs), &bad_wr); if (ret != 0) { - return ret; + std::stringstream err; + err << "ibv_post_send failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); } this->wrn = 0; - return 0; } -int mscclppIbQp::postRecv(uint64_t wrId) +void IbQp::postRecv(uint64_t wrId) { struct ibv_recv_wr wr, *bad_wr; wr.wr_id = wrId; wr.sg_list = nullptr; wr.num_sge = 0; wr.next = nullptr; - return ibv_post_recv(this->qp, &wr, &bad_wr); + int ret = ibv_post_recv(reinterpret_cast(this->qp), &wr, &bad_wr); + if (ret != 0) { + std::stringstream err; + err << "ibv_post_recv failed (errno " << errno << ")"; + throw std::runtime_error(err.str()); + } } -int mscclppIbQp::pollCq() +int IbQp::pollCq() { - return ibv_poll_cq(this->cq, MSCCLPP_IB_CQ_POLL_NUM, this->wcs); + return ibv_poll_cq(reinterpret_cast(this->cq), MSCCLPP_IB_CQ_POLL_NUM, reinterpret_cast(this->wcs)); } -namespace mscclpp { - -IbQp::IbQp(void* ctx, void* pd, int port) +const IbQpInfo& IbQp::getInfo() const { - struct ibv_context* _ctx = static_cast(ctx); - struct ibv_pd* _pd = static_cast(pd); - - this->cq = ibv_create_cq(_ctx, MSCCLPP_IB_CQ_SIZE, nullptr, nullptr, 0); - if (this->cq == nullptr) { - std::stringstream err; - err << "ibv_create_cq failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); - } - - struct ibv_qp_init_attr qpInitAttr; - std::memset(&qpInitAttr, 0, sizeof(qpInitAttr)); - qpInitAttr.sq_sig_all = 0; - qpInitAttr.send_cq = static_cast(this->cq); - qpInitAttr.recv_cq = static_cast(this->cq); - qpInitAttr.qp_type = IBV_QPT_RC; - qpInitAttr.cap.max_send_wr = 
MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; - qpInitAttr.cap.max_recv_wr = MAXCONNECTIONS * MSCCLPP_PROXY_FIFO_SIZE; - qpInitAttr.cap.max_send_sge = 1; - qpInitAttr.cap.max_recv_sge = 1; - qpInitAttr.cap.max_inline_data = 0; - - struct ibv_qp* _qp = ibv_create_qp(_pd, &qpInitAttr); - if (_qp == nullptr) { - std::stringstream err; - err << "ibv_create_qp failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); - } - - struct ibv_port_attr portAttr; - if (ibv_query_port(_ctx, port, &portAttr) != 0) { - std::stringstream err; - err << "ibv_query_port failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); - } - this->info.lid = portAttr.lid; - this->info.port = port; - this->info.linkLayer = portAttr.link_layer; - this->info.qpn = _qp->qp_num; - this->info.mtu = portAttr.active_mtu; - if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND) { - union ibv_gid gid; - if (ibv_query_gid(_ctx, port, 0, &gid) != 0) { - std::stringstream err; - err << "ibv_query_gid failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); - } - this->info.spn = gid.global.subnet_prefix; - } + return this->info; +} - struct ibv_qp_attr qpAttr; - memset(&qpAttr, 0, sizeof(qpAttr)); - qpAttr.qp_state = IBV_QPS_INIT; - qpAttr.pkey_index = 0; - qpAttr.port_num = port; - qpAttr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; - if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { - std::stringstream err; - err << "ibv_modify_qp failed (errno " << errno << ")"; - throw std::runtime_error(err.str()); - } - this->qp = _qp; +const void* IbQp::getWc(int idx) const +{ + return &reinterpret_cast(this->wcs)[idx]; } -IbCtx::IbCtx(const std::string& ibDevName) +IbCtx::IbCtx(const std::string& devName) : devName(devName) { int num; struct ibv_device** devices = ibv_get_device_list(&num); for (int i = 0; i < num; ++i) { - if (std::string(devices[i]->name) == ibDevName) { + if 
(std::string(devices[i]->name) == devName) { this->ctx = ibv_open_device(devices[i]); break; } @@ -443,10 +286,10 @@ IbCtx::IbCtx(const std::string& ibDevName) ibv_free_device_list(devices); if (this->ctx == nullptr) { std::stringstream err; - err << "ibv_open_device failed (errno " << errno << ", device name << " << ibDevName << ")"; + err << "ibv_open_device failed (errno " << errno << ", device name << " << devName << ")"; throw std::runtime_error(err.str()); } - this->pd = ibv_alloc_pd(static_cast(this->ctx)); + this->pd = ibv_alloc_pd(reinterpret_cast(this->ctx)); if (this->pd == nullptr) { std::stringstream err; err << "ibv_alloc_pd failed (errno " << errno << ")"; @@ -456,18 +299,20 @@ IbCtx::IbCtx(const std::string& ibDevName) IbCtx::~IbCtx() { + this->mrs.clear(); + this->qps.clear(); if (this->pd != nullptr) { - ibv_dealloc_pd(static_cast(this->pd)); + ibv_dealloc_pd(reinterpret_cast(this->pd)); } if (this->ctx != nullptr) { - ibv_close_device(static_cast(this->ctx)); + ibv_close_device(reinterpret_cast(this->ctx)); } } bool IbCtx::isPortUsable(int port) const { struct ibv_port_attr portAttr; - if (ibv_query_port(static_cast(this->ctx), port, &portAttr) != 0) { + if (ibv_query_port(reinterpret_cast(this->ctx), port, &portAttr) != 0) { std::stringstream err; err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")"; throw std::runtime_error(err.str()); @@ -479,7 +324,7 @@ bool IbCtx::isPortUsable(int port) const int IbCtx::getAnyActivePort() const { struct ibv_device_attr devAttr; - if (ibv_query_device(static_cast(this->ctx), &devAttr) != 0) { + if (ibv_query_device(reinterpret_cast(this->ctx), &devAttr) != 0) { std::stringstream err; err << "ibv_query_device failed (errno " << errno << ")"; throw std::runtime_error(err.str()); @@ -506,4 +351,89 @@ IbQp* IbCtx::createQp(int port /*=-1*/) return qps.back().get(); } +const IbMr* IbCtx::registerMr(void* buff, std::size_t size) +{ + mrs.emplace_back(new IbMr(this->pd, buff, size)); + 
return mrs.back().get(); +} + +const std::string& IbCtx::getDevName() const +{ + return this->devName; +} + +int getIBDeviceCount() { + int num; + ibv_get_device_list(&num); + return num; +} + +std::string getIBDeviceName(TransportFlags ibTransport) { + int num; + struct ibv_device** devices = ibv_get_device_list(&num); + int ibTransportIndex; + switch (ibTransport) { // TODO: get rid of this ugly switch + case TransportIB0: + ibTransportIndex = 0; + break; + case TransportIB1: + ibTransportIndex = 1; + break; + case TransportIB2: + ibTransportIndex = 2; + break; + case TransportIB3: + ibTransportIndex = 3; + break; + case TransportIB4: + ibTransportIndex = 4; + break; + case TransportIB5: + ibTransportIndex = 5; + break; + case TransportIB6: + ibTransportIndex = 6; + break; + case TransportIB7: + ibTransportIndex = 7; + break; + default: + throw std::runtime_error("Not an IB transport"); + } + if (ibTransportIndex >= num) { + throw std::runtime_error("IB transport out of range"); + } + return devices[ibTransportIndex]->name; +} + +TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName) { + int num; + struct ibv_device** devices = ibv_get_device_list(&num); + for (int i = 0; i < num; ++i) { + if (ibDeviceName == devices[i]->name) { + switch (i) { // TODO: get rid of this ugly switch + case 0: + return TransportIB0; + case 1: + return TransportIB1; + case 2: + return TransportIB2; + case 3: + return TransportIB3; + case 4: + return TransportIB4; + case 5: + return TransportIB5; + case 6: + return TransportIB6; + case 7: + return TransportIB7; + default: + throw std::runtime_error("IB device index out of range"); + } + } + } + throw std::runtime_error("IB device not found"); +} + } // namespace mscclpp diff --git a/src/include/comm.h b/src/include/comm.h index 8275e0cba..dce724fa6 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -7,9 +7,10 @@ #ifndef MSCCLPP_COMM_H_ #define MSCCLPP_COMM_H_ -#include "ib.h" +#include "ib.hpp" #include 
"proxy.h" #include +#include #define MAXCONNECTIONS 64 @@ -31,7 +32,7 @@ struct mscclppConn std::vector bufferRegistrations; std::vector remoteBufferRegistrations; - struct mscclppIbContext* ibCtx; + mscclpp::IbCtx* ibCtx; #if defined(ENABLE_NPKIT) std::vector npkitUsedReqIds; std::vector npkitFreeReqIds; @@ -57,7 +58,7 @@ struct mscclppComm // Flag to ask MSCCLPP kernels to abort volatile uint32_t* abortFlag; - struct mscclppIbContext* ibContext[MSCCLPP_IB_MAX_DEVS]; + std::unique_ptr ibContext[MSCCLPP_IB_MAX_DEVS]; struct mscclppProxyState* proxyState[MSCCLPP_PROXY_MAX_NUM]; }; diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 879501c0c..37abb31ba 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -5,7 +5,7 @@ #include "mscclpp.h" #include "channel.hpp" #include "proxy.hpp" -#include "ib.h" +#include "ib.hpp" #include namespace mscclpp { @@ -15,13 +15,13 @@ class ConnectionBase; struct Communicator::Impl { mscclppComm_t comm; std::vector> connections; - std::unordered_map ibContexts; + std::unordered_map ibContexts; Impl(); ~Impl(); - mscclppIbContext* getIbContext(TransportFlags ibTransport); + IbCtx* getIbContext(TransportFlags ibTransport); }; } // namespace mscclpp diff --git a/src/include/connection.hpp b/src/include/connection.hpp index ac1dd6a17..dcf21362f 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -3,7 +3,7 @@ #include "mscclpp.hpp" #include -#include "ib.h" +#include "ib.hpp" #include "communicator.hpp" namespace mscclpp { @@ -38,7 +38,7 @@ class IBConnection : public ConnectionBase { int tag; TransportFlags transport_; TransportFlags remoteTransport_; - mscclppIbQp* qp; + IbQp* qp; public: IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl); diff --git a/src/include/ib.h b/src/include/ib.h deleted file mode 100644 index 7494ab110..000000000 --- a/src/include/ib.h +++ /dev/null @@ -1,69 +0,0 @@ -#ifndef MSCCLPP_IB_H_ 
-#define MSCCLPP_IB_H_ - -#include "mscclpp.h" -#include -#include -#include -#include - -#define MSCCLPP_IB_CQ_SIZE 1024 -#define MSCCLPP_IB_CQ_POLL_NUM 4 -#define MSCCLPP_IB_MAX_SENDS 64 -#define MSCCLPP_IB_MAX_DEVS 8 - -// QP info to be shared with the remote peer -struct mscclppIbQpInfo -{ - uint16_t lid; - uint8_t port; - uint8_t linkLayer; - uint32_t qpn; - uint64_t spn; - ibv_mtu mtu; -}; - -// IB queue pair -struct mscclppIbQp -{ - struct ibv_qp* qp; - struct mscclppIbQpInfo info; - struct ibv_send_wr* wrs; - struct ibv_sge* sges; - struct ibv_cq* cq; - struct ibv_wc* wcs; - int wrn; - - int rtr(const mscclppIbQpInfo* info); - int rts(); - int stageSend(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId, uint64_t srcOffset, - uint64_t dstOffset, bool signaled); - int stageSendWithImm(struct mscclppIbMr* ibMr, const mscclppIbMrInfo* info, uint32_t size, uint64_t wrId, - uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData); - int postSend(); - int postRecv(uint64_t wrId); - int pollCq(); -}; - -// Holds resources of a single IB device. 
-struct mscclppIbContext -{ - struct ibv_context* ctx; - struct ibv_pd* pd; - int* ports; - int nPorts; - struct mscclppIbQp* qps; - int nQps; - int maxQps; - struct mscclppIbMr* mrs; - int nMrs; - int maxMrs; -}; - -mscclppResult_t mscclppIbContextCreate(struct mscclppIbContext** ctx, const char* ibDevName); -mscclppResult_t mscclppIbContextDestroy(struct mscclppIbContext* ctx); -mscclppResult_t mscclppIbContextCreateQp(struct mscclppIbContext* ctx, struct mscclppIbQp** ibQp, int port = -1); -mscclppResult_t mscclppIbContextRegisterMr(struct mscclppIbContext* ctx, void* buff, size_t size, - struct mscclppIbMr** ibMr); - -#endif diff --git a/src/include/ib.hpp b/src/include/ib.hpp index 85c92af78..d04b75bd2 100644 --- a/src/include/ib.hpp +++ b/src/include/ib.hpp @@ -5,8 +5,38 @@ #include #include +#define MSCCLPP_IB_CQ_SIZE 1024 +#define MSCCLPP_IB_CQ_POLL_NUM 1 +#define MSCCLPP_IB_MAX_SENDS 64 +#define MSCCLPP_IB_MAX_DEVS 8 + namespace mscclpp { +struct IbMrInfo +{ + uint64_t addr; + uint32_t rkey; +}; + +class IbMr +{ +public: + ~IbMr(); + + IbMrInfo getInfo() const; + const void* getBuff() const; + uint32_t getLkey() const; + +private: + IbMr(void* pd, void* buff, std::size_t size); + + void* mr; + void* buff; + std::size_t size; + + friend class IbCtx; +}; + // QP info to be shared with the remote peer struct IbQpInfo { @@ -15,7 +45,7 @@ struct IbQpInfo uint8_t linkLayer; uint32_t qpn; uint64_t spn; - uint32_t mtu; + int mtu; }; class IbQp @@ -23,11 +53,22 @@ class IbQp public: ~IbQp(); - IbQpInfo info; + void rtr(const IbQpInfo& info); + void rts(); + int stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled); + int stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData); + void postSend(); + void postRecv(uint64_t wrId); + int pollCq(); + + const IbQpInfo& getInfo() const; + 
const void* getWc(int idx) const; private: IbQp(void* ctx, void* pd, int port); + IbQpInfo info; + void* qp; void* cq; void* wcs; @@ -38,22 +79,26 @@ class IbQp friend class IbCtx; }; - class IbCtx { public: - IbCtx(const std::string& ibDevName); + IbCtx(const std::string& devName); ~IbCtx(); IbQp* createQp(int port = -1); + const IbMr* registerMr(void* buff, std::size_t size); + + const std::string& getDevName() const; private: bool isPortUsable(int port) const; int getAnyActivePort() const; + const std::string devName; void* ctx; void* pd; std::list> qps; + std::list> mrs; }; } // namespace mscclpp diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index 6f96af103..c01246abe 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -207,25 +207,10 @@ typedef struct char internal[MSCCLPP_UNIQUE_ID_BYTES]; } mscclppUniqueId; -// MR info to be shared with the remote peer -struct mscclppIbMrInfo -{ - uint64_t addr; - uint32_t rkey; -}; - -// IB memory region -struct mscclppIbMr -{ - struct ibv_mr* mr; - void* buff; - struct mscclppIbMrInfo info; -}; - struct mscclppRegisteredMemoryP2P { void* remoteBuff; - mscclppIbMr* IbMr; + const void* IbMr; }; struct mscclppRegisteredMemory diff --git a/src/include/proxy.h b/src/include/proxy.h index 3da0196c7..3746806b7 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -59,7 +59,7 @@ struct mscclppProxyState mscclppProxyRunState_t run; int numaNodeToBind; - struct mscclppIbContext* ibContext; // For IB connection only + mscclpp::IbCtx* ibContext; // For IB connection only cudaStream_t p2pStream; // for P2P DMA engine only struct mscclppProxyFifo fifo; diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index 7a0ab1d02..d2270d468 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -3,7 +3,7 @@ #include "mscclpp.hpp" #include "mscclpp.h" -#include "ib.h" +#include "ib.hpp" #include "communicator.hpp" #include @@ -16,8 +16,8 @@ struct 
TransportInfo { bool ibLocal; union { cudaIpcMemHandle_t cudaIpcHandle; - mscclppIbMr* ibMr; - mscclppIbMrInfo ibMrInfo; + const IbMr* ibMr; + IbMrInfo ibMrInfo; }; }; diff --git a/src/init.cc b/src/init.cc index 7cf159c82..c5b6a66b6 100644 --- a/src/init.cc +++ b/src/init.cc @@ -7,6 +7,7 @@ #include "gdr.h" #endif #include "mscclpp.h" +#include "infiniband/verbs.h" #include #include #include @@ -191,7 +192,7 @@ MSCCLPP_API mscclppResult_t mscclppCommDestroy(mscclppComm_t comm) for (int i = 0; i < MSCCLPP_IB_MAX_DEVS; ++i) { if (comm->ibContext[i]) { - MSCCLPPCHECK(mscclppIbContextDestroy(comm->ibContext[i])); + comm->ibContext[i].reset(nullptr); } } @@ -366,24 +367,17 @@ struct mscclppHostIBConn : mscclppHostConn } void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize) { - this->ibQp->stageSend(this->ibMrs[src], &this->remoteIbMrInfos[dst], (uint32_t)dataSize, + this->ibQp->stageSend(this->ibMrs[src], this->remoteIbMrInfos[dst], (uint32_t)dataSize, /*wrId=*/0, /*srcOffset=*/srcDataOffset, /*dstOffset=*/dstDataOffset, /*signaled=*/false); - int ret = this->ibQp->postSend(); - if (ret != 0) { - // Return value is errno. 
- WARN("data postSend failed: errno %d", ret); - } + this->ibQp->postSend(); npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)dataSize); } void signal() { // My local device flag is copied to the remote's proxy flag - this->ibQp->stageSend(this->ibMrs[0], &this->remoteIbMrInfos[0], sizeof(uint64_t), + this->ibQp->stageSend(this->ibMrs[0], this->remoteIbMrInfos[0], sizeof(uint64_t), /*wrId=*/0, /*srcOffset=*/0, /*dstOffset=*/sizeof(uint64_t), /*signaled=*/true); - int ret = this->ibQp->postSend(); - if (ret != 0) { - WARN("flag postSend failed: errno %d", ret); - } + this->ibQp->postSend(); npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_FLAG_ENTRY, (uint32_t)sizeof(uint64_t)); } void wait() @@ -399,15 +393,11 @@ struct mscclppHostIBConn : mscclppHostConn continue; } for (int i = 0; i < wcNum; ++i) { - struct ibv_wc* wc = &this->ibQp->wcs[i]; + struct ibv_wc* wc = (struct ibv_wc*)this->ibQp->getWc(i); if (wc->status != IBV_WC_SUCCESS) { WARN("wc status %d", wc->status); continue; } - if (wc->qp_num != this->ibQp->qp->qp_num) { - WARN("got wc of unknown qp_num %d", wc->qp_num); - continue; - } if (wc->opcode == IBV_WC_RDMA_WRITE) { isWaiting = false; break; @@ -418,9 +408,9 @@ struct mscclppHostIBConn : mscclppHostConn } mscclppConn* conn; - struct mscclppIbQp* ibQp; - std::vector ibMrs; - std::vector remoteIbMrInfos; + mscclpp::IbQp* ibQp; + std::vector ibMrs; + std::vector remoteIbMrInfos; }; MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag, mscclppTransport_t transportType, const char* ibDev) @@ -458,7 +448,7 @@ MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int if (firstNullIdx == -1) { firstNullIdx = i; } - } else if (strncmp(comm->ibContext[i]->ctx->device->name, ibDev, IBV_SYSFS_NAME_MAX) == 0) { + } else if (strncmp(comm->ibContext[i]->getDevName().c_str(), ibDev, IBV_SYSFS_NAME_MAX) == 0) { ibDevIdx = i; break; } @@ -468,13 +458,10 @@ MSCCLPP_API 
mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int if (ibDevIdx == -1) { // Create a new context. ibDevIdx = firstNullIdx; - if (mscclppIbContextCreate(&comm->ibContext[ibDevIdx], ibDev) != mscclppSuccess) { - WARN("Failed to create IB context"); - return mscclppInternalError; - } + comm->ibContext[ibDevIdx].reset(new mscclpp::IbCtx(std::string(ibDev))); } // Set the ib context for this conn - conn->ibCtx = comm->ibContext[ibDevIdx]; + conn->ibCtx = comm->ibContext[ibDevIdx].get(); } else if (transportType == mscclppTransportP2P) { // do the rest of the initialization later @@ -609,17 +596,17 @@ MSCCLPP_API mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t com struct mscclppBufferRegistrationInfo { cudaIpcMemHandle_t cudaHandle; - mscclppIbMrInfo ibMrInfo; + mscclpp::IbMrInfo ibMrInfo; uint64_t size; }; struct connInfo { - mscclppIbQpInfo infoQp; + mscclpp::IbQpInfo infoQp; std::vector bufferInfos; struct header { - mscclppIbQpInfo infoQp; + mscclpp::IbQpInfo infoQp; int numBufferInfos; }; @@ -702,22 +689,20 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output devConn->remoteBuff = NULL; devConn->remoteSignalEpochId = NULL; - struct mscclppIbContext* ibCtx = conn->ibCtx; + mscclpp::IbCtx* ibCtx = conn->ibCtx; if (hostConn->ibQp == NULL) { - MSCCLPPCHECK(mscclppIbContextCreateQp(ibCtx, &hostConn->ibQp)); + hostConn->ibQp = ibCtx->createQp(); } // Add all registered buffers for (const auto &bufReg : conn->bufferRegistrations) { - hostConn->ibMrs.emplace_back(); - MSCCLPPCHECK(mscclppIbContextRegisterMr(ibCtx, bufReg.data, - sizeof(struct mscclppDevConnSignalEpochId), &hostConn->ibMrs.back())); + hostConn->ibMrs.emplace_back(ibCtx->registerMr(bufReg.data, sizeof(struct mscclppDevConnSignalEpochId))); connInfo->bufferInfos.emplace_back(); - connInfo->bufferInfos.back().ibMrInfo = hostConn->ibMrs.back()->info; + connInfo->bufferInfos.back().ibMrInfo = hostConn->ibMrs.back()->getInfo(); 
connInfo->bufferInfos.back().size = bufReg.size; } - connInfo->infoQp = hostConn->ibQp->info; + connInfo->infoQp = hostConn->ibQp->getInfo(); return mscclppSuccess; } @@ -728,14 +713,8 @@ mscclppResult_t mscclppIbConnectionSetupEnd(struct connInfo* connInfo /*input*/, return mscclppInternalError; } struct mscclppHostIBConn* hostConn = (struct mscclppHostIBConn*)conn->hostConn; - if (hostConn->ibQp->rtr(&connInfo->infoQp) != 0) { - WARN("Failed to transition QP to RTR"); - return mscclppInvalidUsage; - } - if (hostConn->ibQp->rts() != 0) { - WARN("Failed to transition QP to RTS"); - return mscclppInvalidUsage; - } + hostConn->ibQp->rtr(connInfo->infoQp); + hostConn->ibQp->rts(); // No remote pointers to set with IB, so we just set the Mrs @@ -788,25 +767,25 @@ MSCCLPP_API mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm) struct bufferInfo { cudaIpcMemHandle_t handleBuff; - mscclppIbMrInfo infoBuffMr; + mscclpp::IbMrInfo infoBuffMr; }; MSCCLPP_API mscclppResult_t mscclppRegisterBuffer(mscclppComm_t comm, void* local_memory, size_t size, mscclppRegisteredMemory* regMem) { - std::vector ibMrs; + std::vector ibMrs; for (int i = 0; i < comm->nConns; ++i) { struct mscclppConn* conn = &comm->conns[i]; struct bufferInfo bInfo; - struct mscclppIbMr* ibBuffMr; + const mscclpp::IbMr* ibBuffMr; // TODO: (conn->transport & mscclppTransportP2P) to support both P2P and IB if (conn->transport == mscclppTransportP2P) { CUDACHECK(cudaIpcGetMemHandle(&bInfo.handleBuff, local_memory)); } else if (conn->transport == mscclppTransportIB) { - MSCCLPPCHECK(mscclppIbContextRegisterMr(conn->ibCtx, local_memory, size, &ibBuffMr)); - bInfo.infoBuffMr = ibBuffMr->info; - ibMrs.push_back(ibBuffMr); + ibBuffMr = conn->ibCtx->registerMr(local_memory, size); + bInfo.infoBuffMr = ibBuffMr->getInfo(); + ibMrs.emplace_back(ibBuffMr); } MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &bInfo, sizeof(bInfo))); diff --git a/src/proxy.cc b/src/proxy.cc 
index 6cfd799bc..c8bf44145 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -2,7 +2,7 @@ #include "checks.h" #include "comm.h" #include "debug.h" -#include "ib.h" +#include "ib.hpp" #include "socket.h" #include diff --git a/src/registered_memory.cc b/src/registered_memory.cc index d9476e4f9..f0db85ce4 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -18,8 +18,7 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t auto addIb = [&](TransportFlags ibTransport) { TransportInfo transportInfo; transportInfo.transport = ibTransport; - mscclppIbMr* mr; - MSCCLPPTHROW(mscclppIbContextRegisterMr(commImpl.getIbContext(ibTransport), data, size, &mr)); + const IbMr* mr = commImpl.getIbContext(ibTransport)->registerMr(data, size); transportInfo.ibMr = mr; transportInfo.ibLocal = true; this->transportInfos.push_back(transportInfo); @@ -103,7 +102,7 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { it += sizeof(handle); transportInfo.cudaIpcHandle = handle; } else if (transportInfo.transport & TransportAllIB) { - mscclppIbMrInfo info; + IbMrInfo info; std::copy_n(it, sizeof(info), reinterpret_cast(&info)); it += sizeof(info); transportInfo.ibMrInfo = info; diff --git a/tests/unittests/ib_test.cc b/tests/unittests/ib_test.cc index 2c194eafe..6f84398f6 100644 --- a/tests/unittests/ib_test.cc +++ b/tests/unittests/ib_test.cc @@ -1,8 +1,10 @@ #include "alloc.h" #include "checks.h" -#include "ib.h" -#include +#include "ib.hpp" +#include "infiniband/verbs.h" +#include "mscclpp.hpp" #include +#include // Measure current time in second. 
static double getTime(void) @@ -24,8 +26,8 @@ int main(int argc, const char* argv[]) printf("Usage: %s <0(recv)/1(send)> \n", argv[0]); return 1; } - const char* ip_port = argv[1]; - int is_send = atoi(argv[2]); + const char* ipPortPair = argv[1]; + int isSend = atoi(argv[2]); int cudaDevId = atoi(argv[3]); std::string ibDevName = "mlx5_ib" + std::string(argv[4]); @@ -35,51 +37,40 @@ int main(int argc, const char* argv[]) int nelem = 1; MSCCLPPCHECK(mscclppCudaCalloc(&data, nelem)); - mscclppComm_t comm; - MSCCLPPCHECK(mscclppCommInitRank(&comm, 2, ip_port, is_send)); + std::shared_ptr bootstrap(new mscclpp::Bootstrap(isSend, 2)); + bootstrap->initialize(ipPortPair); - struct mscclppIbContext* ctx; - struct mscclppIbQp* qp; - struct mscclppIbMr* mr; - MSCCLPPCHECK(mscclppIbContextCreate(&ctx, ibDevName.c_str())); - MSCCLPPCHECK(mscclppIbContextCreateQp(ctx, &qp)); - MSCCLPPCHECK(mscclppIbContextRegisterMr(ctx, data, sizeof(int) * nelem, &mr)); + mscclpp::IbCtx ctx(ibDevName); + mscclpp::IbQp* qp = ctx.createQp(); + const mscclpp::IbMr* mr = ctx.registerMr(data, sizeof(int) * nelem); - struct mscclppIbQpInfo* qpInfo; - MSCCLPPCHECK(mscclppCalloc(&qpInfo, 2)); - qpInfo[is_send] = qp->info; + std::array qpInfo; + qpInfo[isSend] = qp->getInfo(); - struct mscclppIbMrInfo* mrInfo; - MSCCLPPCHECK(mscclppCalloc(&mrInfo, 2)); - mrInfo[is_send] = mr->info; + std::array mrInfo; + mrInfo[isSend] = mr->getInfo(); - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, qpInfo, sizeof(struct mscclppIbQpInfo))); - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, mrInfo, sizeof(struct mscclppIbMrInfo))); + bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo)); + bootstrap->allGather(mrInfo.data(), sizeof(mscclpp::IbMrInfo)); - for (int i = 0; i < 2; ++i) { - if (i == is_send) + for (int i = 0; i < bootstrap->getNranks(); ++i) { + if (i == isSend) continue; - qp->rtr(&qpInfo[i]); + qp->rtr(qpInfo[i]); qp->rts(); break; } printf("connection succeed\n"); - // A simple barrier - int* tmp; 
- MSCCLPPCHECK(mscclppCalloc(&tmp, 2)); - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); + bootstrap->barrier(); - if (is_send) { + if (isSend) { int maxIter = 100000; double start = getTime(); for (int iter = 0; iter < maxIter; ++iter) { - qp->stageSend(mr, &mrInfo[0], sizeof(int) * nelem, 0, 0, 0, true); - if (qp->postSend() != 0) { - WARN("postSend failed"); - return 1; - } + qp->stageSend(mr, mrInfo[0], sizeof(int) * nelem, 0, 0, 0, true); + qp->postSend(); bool waiting = true; while (waiting) { int wcNum = qp->pollCq(); @@ -88,7 +79,7 @@ int main(int argc, const char* argv[]) return 1; } for (int i = 0; i < wcNum; ++i) { - struct ibv_wc* wc = &qp->wcs[i]; + const struct ibv_wc* wc = reinterpret_cast(qp->getWc(i)); if (wc->status != IBV_WC_SUCCESS) { WARN("wc status %d", wc->status); return 1; @@ -103,10 +94,7 @@ int main(int argc, const char* argv[]) } // A simple barrier - MSCCLPPCHECK(mscclppBootstrapAllGather(comm, tmp, sizeof(int))); - - MSCCLPPCHECK(mscclppIbContextDestroy(ctx)); - MSCCLPPCHECK(mscclppCommDestroy(comm)); + bootstrap->barrier(); return 0; } From 76410382468b78d90a31917f9b09e03bec8847bc Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Thu, 27 Apr 2023 04:15:24 +0000 Subject: [PATCH 13/54] wip --- src/communicator.cc | 5 +---- src/include/communicator.hpp | 3 ++- src/include/mscclpp.hpp | 31 +++---------------------------- tests/bootstrap_test_cpp.cc | 4 ++-- 4 files changed, 8 insertions(+), 35 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index 9ce5b7791..d905748a2 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -11,16 +11,13 @@ namespace mscclpp { -Communicator::Impl::Impl() : comm(nullptr) {} +Communicator::Impl::Impl(std::shared_ptr bootstrap) : bootstrap_(bootstrap) {} Communicator::Impl::~Impl() { for (auto& entry : ibContexts) { mscclppIbContextDestroy(entry.second); } ibContexts.clear(); - if (comm) { - mscclppCommDestroy(comm); - } } mscclppIbContext* 
Communicator::Impl::getIbContext(TransportFlags ibTransport) { diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 879501c0c..7c5289ad0 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -16,8 +16,9 @@ struct Communicator::Impl { mscclppComm_t comm; std::vector> connections; std::unordered_map ibContexts; + std::shared_ptr bootstrap_; - Impl(); + Impl(std::shared_ptr bootstrap); ~Impl(); diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index f2d8667e1..f7e158725 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -120,24 +120,13 @@ class Connection { class Communicator { public: - /* Initialize the communicator. nranks processes with rank 0 to nranks-1 need to call this function. + /* Initialize the communicator. * * Inputs: - * nranks: number of ranks in the communicator - * ipPortPair: a string of the form "ip:port" that represents the address of the root process - * rank: rank of the calling process + * bootstrap: an implementation of the of BaseBootstrap that the communicator will use */ - Communicator(int nranks, const char* ipPortPair, int rank); + Communicator(std::shared_ptr bootstrap); - /* Initialize the communicator from a given UniqueId. Same as mscclppCommInitRank() except that - * id is provided by the user by calling getUniqueId() - * - * Inputs: - * nranks: number of ranks in the communicator - * id: the unique ID to be used for communication - * rank: rank of the calling process - */ - Communicator(int nranks, UniqueId id, int rank); ~Communicator(); @@ -183,20 +172,6 @@ class Communicator { */ void connectionSetup(); - /* Return the rank of the calling process. - * - * Outputs: - * rank: the rank of the calling process - */ - int rank(); - - /* Return the number of ranks of the communicator. 
- * - * Outputs: - * size: the number of ranks of the communicator - */ - int size(); - struct Impl; private: std::unique_ptr pimpl; diff --git a/tests/bootstrap_test_cpp.cc b/tests/bootstrap_test_cpp.cc index 34e58b598..6c29e369d 100644 --- a/tests/bootstrap_test_cpp.cc +++ b/tests/bootstrap_test_cpp.cc @@ -55,11 +55,11 @@ void test_sendrecv(std::shared_ptr bootstrap){ void test_all(std::shared_ptr bootstrap){ test_allgather(bootstrap); test_barrier(bootstrap); - // test_sendrecv(bootstrap); + test_sendrecv(bootstrap); } void test_mscclpp_bootstrap_with_id(int rank, int worldSize){ - std::shared_ptr bootstrap(new mscclpp::Bootstrap(rank, worldSize)); + auto bootstrap = std::make_shared(rank, worldSize); mscclpp::UniqueId id; if (bootstrap->getRank() == 0) id = bootstrap->createUniqueId(); From c24896b62f4d7e906bf2121303837cbae0bd3abd Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Thu, 27 Apr 2023 04:23:44 +0000 Subject: [PATCH 14/54] bootstrap to the communicator --- src/communicator.cc | 22 +--------------------- tests/bootstrap_test_cpp.cc | 4 ++-- 2 files changed, 3 insertions(+), 23 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index 81753fb60..02ee7a872 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -54,15 +54,7 @@ static mscclppTransport_t transportToCStyle(TransportFlags flags) { } } -MSCCLPP_API_CPP Communicator::Communicator(int nranks, const char* ipPortPair, int rank) : pimpl(std::make_unique()) { - mscclppCommInitRank(&pimpl->comm, nranks, ipPortPair, rank); -} - -MSCCLPP_API_CPP Communicator::Communicator(int nranks, UniqueId id, int rank) : pimpl(std::make_unique()) { - static_assert(sizeof(mscclppUniqueId) == sizeof(UniqueId), "UniqueId size mismatch"); - mscclppUniqueId *cstyle_id = reinterpret_cast(&id); - mscclppCommInitRankFromId(&pimpl->comm, nranks, *cstyle_id, rank); -} +MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr bootstrap) : pimpl(std::make_unique(bootstrap)) {} MSCCLPP_API_CPP void 
Communicator::bootstrapAllGather(void* data, int size) { mscclppBootstrapAllGather(pimpl->comm, data, size); @@ -100,16 +92,4 @@ MSCCLPP_API_CPP void Communicator::connectionSetup() { } } -MSCCLPP_API_CPP int Communicator::rank() { - int result; - mscclppCommRank(pimpl->comm, &result); - return result; -} - -MSCCLPP_API_CPP int Communicator::size() { - int result; - mscclppCommSize(pimpl->comm, &result); - return result; -} - } // namespace mscclpp diff --git a/tests/bootstrap_test_cpp.cc b/tests/bootstrap_test_cpp.cc index 6c29e369d..bdde84673 100644 --- a/tests/bootstrap_test_cpp.cc +++ b/tests/bootstrap_test_cpp.cc @@ -24,7 +24,7 @@ void test_barrier(std::shared_ptr bootstrap){ void test_sendrecv(std::shared_ptr bootstrap){ for (int i = 0; i < bootstrap->getNranks(); i++) { - if (bootstrap->getRank() == 0) + if (bootstrap->getRank() == i) continue; int msg1 = (bootstrap->getRank() + 1) * 3; int msg2 = (bootstrap->getRank() + 1) * 3 + 1; @@ -35,7 +35,7 @@ void test_sendrecv(std::shared_ptr bootstrap){ } for (int i = 0; i < bootstrap->getNranks(); i++) { - if (i == bootstrap->getRank()) + if (bootstrap->getRank() == i) continue; int msg1 = 0; int msg2 = 0; From b0c7e869099a5d45222dadda83ea4caee33d2f2b Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 27 Apr 2023 05:01:07 +0000 Subject: [PATCH 15/54] Communicator owns IB contexts --- src/communicator.cc | 10 +++------- src/include/communicator.hpp | 3 ++- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index 02ee7a872..c24b0c5e7 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -15,9 +15,6 @@ namespace mscclpp { Communicator::Impl::Impl(std::shared_ptr bootstrap) : bootstrap_(bootstrap) {} Communicator::Impl::~Impl() { - for (auto& entry : ibContexts) { - delete entry.second; - } ibContexts.clear(); } @@ -26,11 +23,10 @@ IbCtx* Communicator::Impl::getIbContext(TransportFlags ibTransport) { auto it = ibContexts.find(ibTransport); if (it == 
ibContexts.end()) { auto ibDev = getIBDeviceName(ibTransport); - IbCtx* ibCtx = new IbCtx(ibDev); - ibContexts[ibTransport] = ibCtx; - return ibCtx; + ibContexts[ibTransport] = std::make_unique(ibDev); + return ibContexts[ibTransport].get(); } else { - return it->second; + return it->second.get(); } } diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 3c3737aee..53d0fd73f 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -7,6 +7,7 @@ #include "proxy.hpp" #include "ib.hpp" #include +#include namespace mscclpp { @@ -15,7 +16,7 @@ class ConnectionBase; struct Communicator::Impl { mscclppComm_t comm; std::vector> connections; - std::unordered_map ibContexts; + std::unordered_map> ibContexts; std::shared_ptr bootstrap_; Impl(std::shared_ptr bootstrap); From df80d8854bdd8bf89773053a42209eff3784b11e Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Thu, 27 Apr 2023 05:26:08 +0000 Subject: [PATCH 16/54] connect test --- Makefile | 2 +- tests/communicator_test_cpp.cc | 48 ++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 tests/communicator_test_cpp.cc diff --git a/Makefile b/Makefile index b2d2cceb9..950751d79 100644 --- a/Makefile +++ b/Makefile @@ -149,7 +149,7 @@ UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS)) TESTSDIR := tests -TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu bootstrap_test_cpp.cc) # allgather_test_cpp.cu +TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu bootstrap_test_cpp.cc communicator_test_cpp.cc) # allgather_test_cpp.cu TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS)) TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS)) diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc new file 
mode 100644 index 000000000..fc3a72e81 --- /dev/null +++ b/tests/communicator_test_cpp.cc @@ -0,0 +1,48 @@ +#include "mscclpp.hpp" + +#include +#include +#include +#include + +void test_communicator(int rank, int worldSize, int nranksPerNode){ + auto bootstrap = std::make_shared(rank, worldSize); + mscclpp::UniqueId id; + if (bootstrap->getRank() == 0) + id = bootstrap->createUniqueId(); + MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); + bootstrap->initialize(id); + + auto communicator = std::make_shared(bootstrap); + for (int i = 0; i < worldSize; i++){ + if (i != rank){ + if (i % nranksPerNode == rank % nranksPerNode) + auto connect = communicator->connect(i, 0, mscclpp::TransportCudaIpc); + else + auto connect = communicator->connect(i, 0, mscclpp::TransportAllIB); + } + } + + if (bootstrap->getRank() == 0) + std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl; +} + + +int main(int argc, char **argv) +{ + int rank, worldSize; + MPI_Init(&argc, &argv); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &worldSize); + MPI_Comm shmcomm; + MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm); + int shmWorldSize; + MPI_Comm_size(shmcomm, &shmWorldSize); + int nranksPerNode = shmWorldSize; + MPI_Comm_free(&shmcomm); + + test_communicator(rank, worldSize, nranksPerNode); + + MPI_Finalize(); + return 0; +} \ No newline at end of file From 8eda6369ee2b71bfd92e34af021574f7356cfffe Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Thu, 27 Apr 2023 06:08:35 +0000 Subject: [PATCH 17/54] testing connection setup --- src/communicator.cc | 4 ++-- src/connection.cc | 10 +++++----- src/ib.cc | 2 +- src/include/connection.hpp | 12 ++++++------ src/include/ib.hpp | 2 +- tests/communicator_test_cpp.cc | 22 +++++++++++++++++++--- 6 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index c24b0c5e7..7e1348e82 100644 --- a/src/communicator.cc +++ 
b/src/communicator.cc @@ -81,10 +81,10 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank MSCCLPP_API_CPP void Communicator::connectionSetup() { for (auto& conn : pimpl->connections) { - conn->startSetup(*this); + conn->startSetup(pimpl->bootstrap_); } for (auto& conn : pimpl->connections) { - conn->endSetup(*this); + conn->endSetup(pimpl->bootstrap_); } } diff --git a/src/connection.cc b/src/connection.cc index 1e21694cd..fc653c2a4 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -54,7 +54,7 @@ void CudaIpcConnection::flush() { // IBConnection -IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank(remoteRank), tag(tag), transport_(transport), remoteTransport_(TransportNone) { +IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank_(remoteRank), tag_(tag), transport_(transport), remoteTransport_(TransportNone) { qp = commImpl.getIbContext(transport)->createQp(); } @@ -114,15 +114,15 @@ void IBConnection::flush() { // npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT); } -void IBConnection::startSetup(Communicator& comm) { +void IBConnection::startSetup(std::shared_ptr bootstrap) { // TODO(chhwang): temporarily disabled to compile - // comm.bootstrap().send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank, tag); + bootstrap->send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank_, tag_); } -void IBConnection::endSetup(Communicator& comm) { +void IBConnection::endSetup(std::shared_ptr bootstrap) { IbQpInfo qpInfo; // TODO(chhwang): temporarily disabled to compile - // comm.bootstrap().recv(&qpInfo, sizeof(qpInfo), remoteRank, tag); + bootstrap->recv(&qpInfo, sizeof(qpInfo), remoteRank_, tag_); qp->rtr(qpInfo); qp->rts(); } diff --git a/src/ib.cc b/src/ib.cc index 4dc0285ba..fe3334a3c 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -263,7 +263,7 @@ int IbQp::pollCq() return 
ibv_poll_cq(reinterpret_cast(this->cq), MSCCLPP_IB_CQ_POLL_NUM, reinterpret_cast(this->wcs)); } -const IbQpInfo& IbQp::getInfo() const +IbQpInfo& IbQp::getInfo() { return this->info; } diff --git a/src/include/connection.hpp b/src/include/connection.hpp index dcf21362f..132726f79 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -12,8 +12,8 @@ namespace mscclpp { class ConnectionBase : public Connection { public: - virtual void startSetup(Communicator&) {}; - virtual void endSetup(Communicator&) {}; + virtual void startSetup(std::shared_ptr bootstrap) {}; + virtual void endSetup(std::shared_ptr bootstrap) {}; }; class CudaIpcConnection : public ConnectionBase { @@ -34,8 +34,8 @@ class CudaIpcConnection : public ConnectionBase { }; class IBConnection : public ConnectionBase { - int remoteRank; - int tag; + int remoteRank_; + int tag_; TransportFlags transport_; TransportFlags remoteTransport_; IbQp* qp; @@ -53,9 +53,9 @@ class IBConnection : public ConnectionBase { void flush() override; - void startSetup(Communicator& comm) override; + void startSetup(std::shared_ptr bootstrap) override; - void endSetup(Communicator& comm) override; + void endSetup(std::shared_ptr bootstrap) override; }; } // namespace mscclpp diff --git a/src/include/ib.hpp b/src/include/ib.hpp index d04b75bd2..b1baeb757 100644 --- a/src/include/ib.hpp +++ b/src/include/ib.hpp @@ -61,7 +61,7 @@ class IbQp void postRecv(uint64_t wrId); int pollCq(); - const IbQpInfo& getInfo() const; + IbQpInfo& getInfo(); const void* getWc(int idx) const; private: diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index fc3a72e81..d3fe15b00 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -5,6 +5,20 @@ #include #include +mscclpp::TransportFlags findIb(int localRank){ + mscclpp::TransportFlags IBs[] = { + mscclpp::TransportIB0, + mscclpp::TransportIB1, + mscclpp::TransportIB2, + mscclpp::TransportIB3, + mscclpp::TransportIB4, + 
mscclpp::TransportIB5, + mscclpp::TransportIB6, + mscclpp::TransportIB7 + }; + return IBs[localRank]; +} + void test_communicator(int rank, int worldSize, int nranksPerNode){ auto bootstrap = std::make_shared(rank, worldSize); mscclpp::UniqueId id; @@ -16,12 +30,14 @@ void test_communicator(int rank, int worldSize, int nranksPerNode){ auto communicator = std::make_shared(bootstrap); for (int i = 0; i < worldSize; i++){ if (i != rank){ - if (i % nranksPerNode == rank % nranksPerNode) + if (i % nranksPerNode == rank % nranksPerNode){ auto connect = communicator->connect(i, 0, mscclpp::TransportCudaIpc); - else - auto connect = communicator->connect(i, 0, mscclpp::TransportAllIB); + } else { + auto connect = communicator->connect(i, 0, findIb(rank % nranksPerNode)); + } } } + communicator->connectionSetup(); if (bootstrap->getRank() == 0) std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl; From 06c6df2350427fb9b7d955075b95c2705c3326d3 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Thu, 27 Apr 2023 19:06:35 +0000 Subject: [PATCH 18/54] Separate out Transport and TransportFlags --- src/communicator.cc | 26 +---- src/connection.cc | 20 ++-- src/ib.cc | 36 +++---- src/include/communicator.hpp | 4 +- src/include/connection.hpp | 14 +-- src/include/mscclpp.hpp | 153 ++++++++++++++++++++++++++---- src/include/registered_memory.hpp | 4 +- src/registered_memory.cc | 36 +++---- tests/communicator_test_cpp.cc | 22 ++--- 9 files changed, 205 insertions(+), 110 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index 7e1348e82..1420c51d9 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -18,7 +18,7 @@ Communicator::Impl::~Impl() { ibContexts.clear(); } -IbCtx* Communicator::Impl::getIbContext(TransportFlags ibTransport) { +IbCtx* Communicator::Impl::getIbContext(Transport ibTransport) { // Find IB context or create it auto it = ibContexts.find(ibTransport); if (it == ibContexts.end()) { @@ -32,24 +32,6 @@ IbCtx* 
Communicator::Impl::getIbContext(TransportFlags ibTransport) { MSCCLPP_API_CPP Communicator::~Communicator() = default; -static mscclppTransport_t transportToCStyle(TransportFlags flags) { - switch (flags) { - case TransportIB0: - case TransportIB1: - case TransportIB2: - case TransportIB3: - case TransportIB4: - case TransportIB5: - case TransportIB6: - case TransportIB7: - return mscclppTransportIB; - case TransportCudaIpc: - return mscclppTransportP2P; - default: - throw std::runtime_error("Unsupported conversion"); - } -} - MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr bootstrap) : pimpl(std::make_unique(bootstrap)) {} MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) { @@ -64,12 +46,12 @@ RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportF return RegisteredMemory(std::make_shared(ptr, size, pimpl->comm->rank, transports, *pimpl)); } -MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, TransportFlags transport) { +MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, Transport transport) { std::shared_ptr conn; - if (transport | TransportCudaIpc) { + if (transport == Transport::CudaIpc) { auto cudaIpcConn = std::make_shared(); conn = cudaIpcConn; - } else if (transport | TransportAllIB) { + } else if (AllIBTransports.has(transport)) { auto ibConn = std::make_shared(remoteRank, tag, transport, *pimpl); conn = ibConn; } else { diff --git a/src/connection.cc b/src/connection.cc index fc653c2a4..031f63ec5 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -6,8 +6,8 @@ namespace mscclpp { -void validateTransport(RegisteredMemory mem, TransportFlags transport) { - if ((mem.transports() & transport) == TransportNone) { +void validateTransport(RegisteredMemory mem, Transport transport) { + if (!mem.transports().has(transport)) { throw std::runtime_error("mem does not support transport"); } } @@ -28,12 +28,12 @@ 
CudaIpcConnection::~CudaIpcConnection() { cudaStreamDestroy(stream); } -TransportFlags CudaIpcConnection::transport() { - return TransportCudaIpc; +Transport CudaIpcConnection::transport() { + return Transport::CudaIpc; } -TransportFlags CudaIpcConnection::remoteTransport() { - return TransportCudaIpc; +Transport CudaIpcConnection::remoteTransport() { + return Transport::CudaIpc; } void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { @@ -54,7 +54,7 @@ void CudaIpcConnection::flush() { // IBConnection -IBConnection::IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl) : remoteRank_(remoteRank), tag_(tag), transport_(transport), remoteTransport_(TransportNone) { +IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl) : remoteRank_(remoteRank), tag_(tag), transport_(transport), remoteTransport_(Transport::Unknown) { qp = commImpl.getIbContext(transport)->createQp(); } @@ -62,11 +62,11 @@ IBConnection::~IBConnection() { // TODO: Destroy QP? 
} -TransportFlags IBConnection::transport() { +Transport IBConnection::transport() { return transport_; } -TransportFlags IBConnection::remoteTransport() { +Transport IBConnection::remoteTransport() { return remoteTransport_; } @@ -115,13 +115,11 @@ void IBConnection::flush() { } void IBConnection::startSetup(std::shared_ptr bootstrap) { - // TODO(chhwang): temporarily disabled to compile bootstrap->send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank_, tag_); } void IBConnection::endSetup(std::shared_ptr bootstrap) { IbQpInfo qpInfo; - // TODO(chhwang): temporarily disabled to compile bootstrap->recv(&qpInfo, sizeof(qpInfo), remoteRank_, tag_); qp->rtr(qpInfo); qp->rts(); diff --git a/src/ib.cc b/src/ib.cc index fe3334a3c..88d14d8ef 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -368,33 +368,33 @@ int getIBDeviceCount() { return num; } -std::string getIBDeviceName(TransportFlags ibTransport) { +std::string getIBDeviceName(Transport ibTransport) { int num; struct ibv_device** devices = ibv_get_device_list(&num); int ibTransportIndex; switch (ibTransport) { // TODO: get rid of this ugly switch - case TransportIB0: + case Transport::IB0: ibTransportIndex = 0; break; - case TransportIB1: + case Transport::IB1: ibTransportIndex = 1; break; - case TransportIB2: + case Transport::IB2: ibTransportIndex = 2; break; - case TransportIB3: + case Transport::IB3: ibTransportIndex = 3; break; - case TransportIB4: + case Transport::IB4: ibTransportIndex = 4; break; - case TransportIB5: + case Transport::IB5: ibTransportIndex = 5; break; - case TransportIB6: + case Transport::IB6: ibTransportIndex = 6; break; - case TransportIB7: + case Transport::IB7: ibTransportIndex = 7; break; default: @@ -406,28 +406,28 @@ std::string getIBDeviceName(TransportFlags ibTransport) { return devices[ibTransportIndex]->name; } -TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName) { +Transport getIBTransportByDeviceName(const std::string& ibDeviceName) { int num; struct 
ibv_device** devices = ibv_get_device_list(&num); for (int i = 0; i < num; ++i) { if (ibDeviceName == devices[i]->name) { switch (i) { // TODO: get rid of this ugly switch case 0: - return TransportIB0; + return Transport::IB0; case 1: - return TransportIB1; + return Transport::IB1; case 2: - return TransportIB2; + return Transport::IB2; case 3: - return TransportIB3; + return Transport::IB3; case 4: - return TransportIB4; + return Transport::IB4; case 5: - return TransportIB5; + return Transport::IB5; case 6: - return TransportIB6; + return Transport::IB6; case 7: - return TransportIB7; + return Transport::IB7; default: throw std::runtime_error("IB device index out of range"); } diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 53d0fd73f..8ca4e952e 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -16,14 +16,14 @@ class ConnectionBase; struct Communicator::Impl { mscclppComm_t comm; std::vector> connections; - std::unordered_map> ibContexts; + std::unordered_map> ibContexts; std::shared_ptr bootstrap_; Impl(std::shared_ptr bootstrap); ~Impl(); - IbCtx* getIbContext(TransportFlags ibTransport); + IbCtx* getIbContext(Transport ibTransport); }; } // namespace mscclpp diff --git a/src/include/connection.hpp b/src/include/connection.hpp index 132726f79..bd08802c1 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -24,9 +24,9 @@ class CudaIpcConnection : public ConnectionBase { ~CudaIpcConnection(); - TransportFlags transport() override; + Transport transport() override; - TransportFlags remoteTransport() override; + Transport remoteTransport() override; void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; @@ -36,18 +36,18 @@ class CudaIpcConnection : public ConnectionBase { class IBConnection : public ConnectionBase { int remoteRank_; int tag_; - TransportFlags transport_; - TransportFlags remoteTransport_; + Transport 
transport_; + Transport remoteTransport_; IbQp* qp; public: - IBConnection(int remoteRank, int tag, TransportFlags transport, Communicator::Impl& commImpl); + IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl); ~IBConnection(); - TransportFlags transport() override; + Transport transport() override; - TransportFlags remoteTransport() override; + Transport remoteTransport() override; void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index f14e19c19..3b9c6d8d5 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -9,6 +9,7 @@ #include #include #include +#include namespace mscclpp { @@ -63,24 +64,129 @@ class Bootstrap : public BaseBootstrap */ std::unique_ptr getUniqueId(); -using TransportFlags = uint32_t; -const TransportFlags TransportNone = 0b0; -const TransportFlags TransportCudaIpc = 0b1; -const TransportFlags TransportIB0 = 0b10; -const TransportFlags TransportIB1 = 0b100; -const TransportFlags TransportIB2 = 0b1000; -const TransportFlags TransportIB3 = 0b10000; -const TransportFlags TransportIB4 = 0b100000; -const TransportFlags TransportIB5 = 0b1000000; -const TransportFlags TransportIB6 = 0b10000000; -const TransportFlags TransportIB7 = 0b100000000; - -const TransportFlags TransportAll = 0b111111111; -const TransportFlags TransportAllIB = 0b111111110; +enum class Transport { + Unknown, + CudaIpc, + IB0, + IB1, + IB2, + IB3, + IB4, + IB5, + IB6, + IB7, + NumTransports +}; + +namespace detail { + const size_t TransportFlagsSize = 10; + static_assert(TransportFlagsSize == static_cast(Transport::NumTransports), "TransportFlagsSize must match the number of transports"); + using TransportFlagsBase = std::bitset; +} + +class TransportFlags : private detail::TransportFlagsBase { +public: + TransportFlags() = default; + TransportFlags(Transport transport) : 
detail::TransportFlagsBase(1 << static_cast(transport)) {} + + bool has(Transport transport) const { + return detail::TransportFlagsBase::test(static_cast(transport)); + } + + bool none() const { + return detail::TransportFlagsBase::none(); + } + + bool any() const { + return detail::TransportFlagsBase::any(); + } + + bool all() const { + return detail::TransportFlagsBase::all(); + } + + size_t count() const { + return detail::TransportFlagsBase::count(); + } + + TransportFlags& operator|=(TransportFlags other) { + detail::TransportFlagsBase::operator|=(other); + return *this; + } + + TransportFlags operator|(TransportFlags other) const { + return TransportFlags(*this) |= other; + } + + TransportFlags operator|(Transport transport) const { + return *this | TransportFlags(transport); + } + + TransportFlags& operator&=(TransportFlags other) { + detail::TransportFlagsBase::operator&=(other); + return *this; + } + + TransportFlags operator&(TransportFlags other) const { + return TransportFlags(*this) &= other; + } + + TransportFlags operator&(Transport transport) const { + return *this & TransportFlags(transport); + } + + TransportFlags& operator^=(TransportFlags other) { + detail::TransportFlagsBase::operator^=(other); + return *this; + } + + TransportFlags operator^(TransportFlags other) const { + return TransportFlags(*this) ^= other; + } + + TransportFlags operator^(Transport transport) const { + return *this ^ TransportFlags(transport); + } + + TransportFlags operator~() const { + return TransportFlags(*this).flip(); + } + + bool operator==(TransportFlags other) const { + return detail::TransportFlagsBase::operator==(other); + } + + bool operator!=(TransportFlags other) const { + return detail::TransportFlagsBase::operator!=(other); + } + + detail::TransportFlagsBase toBitset() const { + return *this; + } + +private: + TransportFlags(detail::TransportFlagsBase bitset) : detail::TransportFlagsBase(bitset) {} +}; + +inline TransportFlags operator|(Transport 
transport1, Transport transport2) { + return TransportFlags(transport1) | transport2; +} + +inline TransportFlags operator&(Transport transport1, Transport transport2) { + return TransportFlags(transport1) & transport2; +} + +inline TransportFlags operator^(Transport transport1, Transport transport2) { + return TransportFlags(transport1) ^ transport2; +} + +const TransportFlags NoTransports = TransportFlags(); +const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transport::IB2 | Transport::IB3 | Transport::IB4 | Transport::IB5 | Transport::IB6 | Transport::IB7; +const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc; int getIBDeviceCount(); -std::string getIBDeviceName(TransportFlags ibTransport); -TransportFlags getIBTransportByDeviceName(const std::string& ibDeviceName); +std::string getIBDeviceName(Transport ibTransport); +Transport getIBTransportByDeviceName(const std::string& ibDeviceName); class Communicator; class Connection; @@ -111,9 +217,9 @@ class Connection { virtual void flush() = 0; - virtual TransportFlags transport() = 0; + virtual Transport transport() = 0; - virtual TransportFlags remoteTransport() = 0; + virtual Transport remoteTransport() = 0; protected: static std::shared_ptr getRegisteredMemoryImpl(RegisteredMemory&); @@ -166,7 +272,7 @@ class Communicator { * transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB) * ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P. */ - std::shared_ptr connect(int remoteRank, int tag, TransportFlags transport); + std::shared_ptr connect(int remoteRank, int tag, Transport transport); /* Establish all connections declared by connect(). This function must be called after all connect() * calls are made. This function ensures that all remote ranks are ready to communicate when it returns. 
@@ -180,4 +286,13 @@ class Communicator { } // namespace mscclpp +namespace std { + template <> + struct hash { + size_t operator()(const mscclpp::TransportFlags& flags) const { + return hash()(flags.toBitset()); + } + }; +} + #endif // MSCCLPP_H_ diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index d2270d468..afe42da45 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -10,7 +10,7 @@ namespace mscclpp { struct TransportInfo { - TransportFlags transport; + Transport transport; // TODO: rewrite this using std::variant or something bool ibLocal; @@ -31,7 +31,7 @@ struct RegisteredMemory::Impl { Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl); Impl(const std::vector& data); - TransportInfo& getTransportInfo(TransportFlags transport) { + TransportInfo& getTransportInfo(Transport transport) { for (auto& entry : transportInfos) { if (entry.transport == transport) { return entry; diff --git a/src/registered_memory.cc b/src/registered_memory.cc index f0db85ce4..b26ea2d54 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -5,17 +5,17 @@ namespace mscclpp { RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) : data(data), size(size), rank(rank), transports(transports) { - if (transports & TransportCudaIpc) { + if (transports.has(Transport::CudaIpc)) { TransportInfo transportInfo; - transportInfo.transport = TransportCudaIpc; + transportInfo.transport = Transport::CudaIpc; cudaIpcMemHandle_t handle; // TODO: translate data to a base pointer CUDATHROW(cudaIpcGetMemHandle(&handle, data)); transportInfo.cudaIpcHandle = handle; this->transportInfos.push_back(transportInfo); } - if (transports & TransportAllIB) { - auto addIb = [&](TransportFlags ibTransport) { + if ((transports & AllIBTransports).any()) { + auto addIb = [&](Transport ibTransport) { TransportInfo 
transportInfo; transportInfo.transport = ibTransport; const IbMr* mr = commImpl.getIbContext(ibTransport)->registerMr(data, size); @@ -23,14 +23,14 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t transportInfo.ibLocal = true; this->transportInfos.push_back(transportInfo); }; - if (transports & TransportIB0) addIb(TransportIB0); - if (transports & TransportIB1) addIb(TransportIB1); - if (transports & TransportIB2) addIb(TransportIB2); - if (transports & TransportIB3) addIb(TransportIB3); - if (transports & TransportIB4) addIb(TransportIB4); - if (transports & TransportIB5) addIb(TransportIB5); - if (transports & TransportIB6) addIb(TransportIB6); - if (transports & TransportIB7) addIb(TransportIB7); + if (transports.has(Transport::IB0)) addIb(Transport::IB0); + if (transports.has(Transport::IB1)) addIb(Transport::IB1); + if (transports.has(Transport::IB2)) addIb(Transport::IB2); + if (transports.has(Transport::IB3)) addIb(Transport::IB3); + if (transports.has(Transport::IB4)) addIb(Transport::IB4); + if (transports.has(Transport::IB5)) addIb(Transport::IB5); + if (transports.has(Transport::IB6)) addIb(Transport::IB6); + if (transports.has(Transport::IB7)) addIb(Transport::IB7); } } @@ -66,9 +66,9 @@ std::vector RegisteredMemory::serialize() { std::copy_n(reinterpret_cast(&transportCount), sizeof(transportCount), std::back_inserter(result)); for (auto& entry : pimpl->transportInfos) { std::copy_n(reinterpret_cast(&entry.transport), sizeof(entry.transport), std::back_inserter(result)); - if (entry.transport == TransportCudaIpc) { + if (entry.transport == Transport::CudaIpc) { std::copy_n(reinterpret_cast(&entry.cudaIpcHandle), sizeof(entry.cudaIpcHandle), std::back_inserter(result)); - } else if (entry.transport & TransportAllIB) { + } else if (AllIBTransports.has(entry.transport)) { std::copy_n(reinterpret_cast(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result)); } else { throw std::runtime_error("Unknown 
transport"); @@ -96,12 +96,12 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { TransportInfo transportInfo; std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast(&transportInfo.transport)); it += sizeof(transportInfo.transport); - if (transportInfo.transport & TransportCudaIpc) { + if (transportInfo.transport == Transport::CudaIpc) { cudaIpcMemHandle_t handle; std::copy_n(it, sizeof(handle), reinterpret_cast(&handle)); it += sizeof(handle); transportInfo.cudaIpcHandle = handle; - } else if (transportInfo.transport & TransportAllIB) { + } else if (AllIBTransports.has(transportInfo.transport)) { IbMrInfo info; std::copy_n(it, sizeof(info), reinterpret_cast(&info)); it += sizeof(info); @@ -116,8 +116,8 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { throw std::runtime_error("Deserialization failed"); } - if (transports & TransportCudaIpc) { - auto entry = getTransportInfo(TransportCudaIpc); + if (transports.has(Transport::CudaIpc)) { + auto entry = getTransportInfo(Transport::CudaIpc); CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); } } diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index d3fe15b00..9ca469888 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -5,16 +5,16 @@ #include #include -mscclpp::TransportFlags findIb(int localRank){ - mscclpp::TransportFlags IBs[] = { - mscclpp::TransportIB0, - mscclpp::TransportIB1, - mscclpp::TransportIB2, - mscclpp::TransportIB3, - mscclpp::TransportIB4, - mscclpp::TransportIB5, - mscclpp::TransportIB6, - mscclpp::TransportIB7 +mscclpp::Transport findIb(int localRank){ + mscclpp::Transport IBs[] = { + mscclpp::Transport::IB0, + mscclpp::Transport::IB1, + mscclpp::Transport::IB2, + mscclpp::Transport::IB3, + mscclpp::Transport::IB4, + mscclpp::Transport::IB5, + mscclpp::Transport::IB6, + mscclpp::Transport::IB7 }; return IBs[localRank]; } @@ -31,7 +31,7 @@ void 
test_communicator(int rank, int worldSize, int nranksPerNode){ for (int i = 0; i < worldSize; i++){ if (i != rank){ if (i % nranksPerNode == rank % nranksPerNode){ - auto connect = communicator->connect(i, 0, mscclpp::TransportCudaIpc); + auto connect = communicator->connect(i, 0, mscclpp::Transport::CudaIpc); } else { auto connect = communicator->connect(i, 0, findIb(rank % nranksPerNode)); } From aaa3f0e94521c5c3ec24a67b6a39e6aa31c71917 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Thu, 27 Apr 2023 19:17:19 +0000 Subject: [PATCH 19/54] host hashes in communicator --- src/communicator.cc | 20 +++++++++++++++++++- src/include/communicator.hpp | 1 + tests/communicator_test_cpp.cc | 3 ++- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index 7e1348e82..6f458fe56 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -1,3 +1,5 @@ +#include + #include "mscclpp.hpp" #include "communicator.hpp" #include "host_connection.hpp" @@ -12,7 +14,13 @@ namespace mscclpp { -Communicator::Impl::Impl(std::shared_ptr bootstrap) : bootstrap_(bootstrap) {} +Communicator::Impl::Impl(std::shared_ptr bootstrap) : bootstrap_(bootstrap) { + rankToHash_.resize(bootstrap->getNranks()); + auto hostHash = getHostHash(); + INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash); + rankToHash_[bootstrap->getRank()] = hostHash; + bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t)); +} Communicator::Impl::~Impl() { ibContexts.clear(); @@ -67,11 +75,21 @@ RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportF MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, TransportFlags transport) { std::shared_ptr conn; if (transport | TransportCudaIpc) { + // sanity check: make sure the IPC connection is being made within a node + if (pimpl->rankToHash_[remoteRank] != pimpl->rankToHash_[pimpl->bootstrap_->getRank()]) { + std::stringstream ss; + ss << "Cuda IPC connection can only be made 
within a node: " << remoteRank << " != " << pimpl->bootstrap_->getRank(); + throw std::runtime_error(ss.str()); + } auto cudaIpcConn = std::make_shared(); conn = cudaIpcConn; + INFO(MSCCLPP_INIT, "Cuda IPC connection between %d(%lx) and %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], + remoteRank, pimpl->rankToHash_[remoteRank]); } else if (transport | TransportAllIB) { auto ibConn = std::make_shared(remoteRank, tag, transport, *pimpl); conn = ibConn; + INFO(MSCCLPP_INIT, "IB connection between %d(%lx) via %s and %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], + getIBDeviceName(transport).c_str(), remoteRank, pimpl->rankToHash_[remoteRank]); } else { throw std::runtime_error("Unsupported transport"); } diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 53d0fd73f..5be00a67c 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -18,6 +18,7 @@ struct Communicator::Impl { std::vector> connections; std::unordered_map> ibContexts; std::shared_ptr bootstrap_; + std::vector rankToHash_; Impl(std::shared_ptr bootstrap); diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index d3fe15b00..05595313e 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -30,7 +30,8 @@ void test_communicator(int rank, int worldSize, int nranksPerNode){ auto communicator = std::make_shared(bootstrap); for (int i = 0; i < worldSize; i++){ if (i != rank){ - if (i % nranksPerNode == rank % nranksPerNode){ + if (i / nranksPerNode == rank / nranksPerNode){ + printf("i %d rank %d nranksPerNode %d\n", i, rank, nranksPerNode); auto connect = communicator->connect(i, 0, mscclpp::TransportCudaIpc); } else { auto connect = communicator->connect(i, 0, findIb(rank % nranksPerNode)); From afc5887da20a24a0d0ec03a2be6f309e12a44d54 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Thu, 27 Apr 2023 20:32:06 
+0000 Subject: [PATCH 20/54] moving the debug info into other levels --- src/communicator.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/communicator.cc b/src/communicator.cc index 726efbc8b..bdccf8ebf 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -66,10 +66,12 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank } auto cudaIpcConn = std::make_shared(); conn = cudaIpcConn; + INFO(MSCCLPP_P2P, "Cuda IPC connection between rank %d(%lx) and remoteRank %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], + remoteRank, pimpl->rankToHash_[remoteRank]); } else if (AllIBTransports.has(transport)) { auto ibConn = std::make_shared(remoteRank, tag, transport, *pimpl); conn = ibConn; - INFO(MSCCLPP_INIT, "IB connection between %d(%lx) via %s and %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], + INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], getIBDeviceName(transport).c_str(), remoteRank, pimpl->rankToHash_[remoteRank]); } else { throw std::runtime_error("Unsupported transport"); From 82c27625e604c7ccd3d138adefddf1778b0e0e09 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Thu, 27 Apr 2023 21:33:15 +0000 Subject: [PATCH 21/54] ipc uses a base ptr now --- Makefile | 2 +- src/basic_proxy_handler.cc | 8 +- src/bootstrap/bootstrap.cc | 10 +- src/communicator.cc | 61 ++++++---- src/connection.cc | 60 ++++++---- src/epoch.cc | 13 +- src/fifo.cc | 29 +++-- src/host_connection.cc | 55 ++++++--- src/ib.cc | 131 +++++++++++--------- src/include/basic_proxy_handler.hpp | 4 +- src/include/channel.hpp | 70 +++++++---- src/include/checks.hpp | 11 ++ src/include/comm.h | 4 +- src/include/communicator.hpp | 11 +- src/include/connection.hpp | 27 +++-- src/include/epoch.hpp | 11 +- src/include/host_connection.hpp 
| 7 +- src/include/ib.hpp | 8 +- src/include/mscclpp.h | 10 +- src/include/mscclpp.hpp | 180 +++++++++++++++++----------- src/include/mscclppfifo.hpp | 25 ++-- src/include/proxy.h | 2 +- src/include/proxy.hpp | 10 +- src/include/registered_memory.hpp | 15 ++- src/include/registered_ptr.hpp | 32 +++-- src/init.cc | 47 +++++--- src/proxy_cpp.cc | 28 +++-- src/registered_memory.cc | 60 +++++++--- tests/allgather_test_cpp.cu | 43 +++---- tests/bootstrap_test_cpp.cc | 52 +++++--- tests/communicator_test_cpp.cc | 32 ++--- tests/unittests/ib_test.cc | 2 +- 32 files changed, 649 insertions(+), 411 deletions(-) diff --git a/Makefile b/Makefile index 950751d79..41896041f 100644 --- a/Makefile +++ b/Makefile @@ -61,7 +61,7 @@ endif NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++11 --expt-extended-lambda -Xfatbin -compress-all # Use addprefix so that we can specify more than one path -NVLDFLAGS := -L$(CUDA_LIB) -lcudart -lrt +NVLDFLAGS := -L$(CUDA_LIB) -lcudart -lrt -lcuda ifeq ($(DEBUG), 0) NVCUFLAGS += -O3 diff --git a/src/basic_proxy_handler.cc b/src/basic_proxy_handler.cc index 482aa8422..424701315 100644 --- a/src/basic_proxy_handler.cc +++ b/src/basic_proxy_handler.cc @@ -2,15 +2,17 @@ namespace mscclpp { -ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm) { +ProxyHandler makeBasicProxyHandler(Communicator::Impl& comm) +{ return [&comm](ProxyTrigger triggerRaw) { - ChannelTrigger *trigger = reinterpret_cast(&triggerRaw); + ChannelTrigger* trigger = reinterpret_cast(&triggerRaw); HostConnection& conn = *comm.connections.at(trigger->fields.connId); auto result = ProxyHandlerResult::Continue; if (trigger->fields.type & mscclppData) { - conn.put(trigger->fields.dstBufferHandle, trigger->fields.dstOffset, trigger->fields.srcBufferHandle, trigger->fields.srcOffset, trigger->fields.size); + conn.put(trigger->fields.dstBufferHandle, trigger->fields.dstOffset, trigger->fields.srcBufferHandle, + trigger->fields.srcOffset, trigger->fields.size); } if 
(trigger->fields.type & mscclppFlag) { diff --git a/src/bootstrap/bootstrap.cc b/src/bootstrap/bootstrap.cc index dfce50b4c..752257998 100644 --- a/src/bootstrap/bootstrap.cc +++ b/src/bootstrap/bootstrap.cc @@ -180,9 +180,8 @@ Bootstrap::Impl::~Impl() } } -void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, - std::vector& rankAddresses, - std::vector& rankAddressesRoot, int& rank) +void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, std::vector& rankAddresses, + std::vector& rankAddressesRoot, int& rank) { mscclppSocket sock; ExtInfo info; @@ -211,7 +210,7 @@ void Bootstrap::Impl::getRemoteAddresses(mscclppSocket* listenSock, } void Bootstrap::Impl::sendHandleToPeer(int peer, const std::vector& rankAddresses, - const std::vector& rankAddressesRoot) + const std::vector& rankAddressesRoot) { mscclppSocket sock; int next = (peer + 1) % this->nRanks_; @@ -226,7 +225,8 @@ void Bootstrap::Impl::bootstrapCreateRoot() mscclppSocket listenSock; // mscclppSocket* listenSock = new mscclppSocket(); // TODO(saemal) make this a shared ptr - MSCCLPPTHROW(mscclppSocketInit(&listenSock, &uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, nullptr, 0)); + MSCCLPPTHROW( + mscclppSocketInit(&listenSock, &uniqueId_.addr, uniqueId_.magic, mscclppSocketTypeBootstrap, nullptr, 0)); MSCCLPPTHROW(mscclppSocketListen(&listenSock)); MSCCLPPTHROW(mscclppSocketGetAddr(&listenSock, &uniqueId_.addr)); auto lambda = [this, listenSock]() { this->bootstrapRoot(listenSock); }; diff --git a/src/communicator.cc b/src/communicator.cc index bdccf8ebf..78df252d0 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -1,20 +1,21 @@ #include -#include "mscclpp.hpp" -#include "communicator.hpp" -#include "host_connection.hpp" -#include "comm.h" -#include "basic_proxy_handler.hpp" #include "api.h" -#include "utils.h" +#include "basic_proxy_handler.hpp" #include "checks.hpp" -#include "debug.h" +#include "comm.h" +#include "communicator.hpp" #include 
"connection.hpp" +#include "debug.h" +#include "host_connection.hpp" +#include "mscclpp.hpp" #include "registered_memory.hpp" +#include "utils.h" namespace mscclpp { -Communicator::Impl::Impl(std::shared_ptr bootstrap) : bootstrap_(bootstrap) { +Communicator::Impl::Impl(std::shared_ptr bootstrap) : bootstrap_(bootstrap) +{ rankToHash_.resize(bootstrap->getNranks()); auto hostHash = getHostHash(); INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash); @@ -22,11 +23,13 @@ Communicator::Impl::Impl(std::shared_ptr bootstrap) : bootstrap_( bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t)); } -Communicator::Impl::~Impl() { +Communicator::Impl::~Impl() +{ ibContexts.clear(); } -IbCtx* Communicator::Impl::getIbContext(Transport ibTransport) { +IbCtx* Communicator::Impl::getIbContext(Transport ibTransport) +{ // Find IB context or create it auto it = ibContexts.find(ibTransport); if (it == ibContexts.end()) { @@ -40,39 +43,50 @@ IbCtx* Communicator::Impl::getIbContext(Transport ibTransport) { MSCCLPP_API_CPP Communicator::~Communicator() = default; -MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr bootstrap) : pimpl(std::make_unique(bootstrap)) {} +MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr bootstrap) + : pimpl(std::make_unique(bootstrap)) +{ +} -MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) { +MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) +{ mscclppBootstrapAllGather(pimpl->comm, data, size); } -MSCCLPP_API_CPP void Communicator::bootstrapBarrier() { +MSCCLPP_API_CPP void Communicator::bootstrapBarrier() +{ mscclppBootstrapBarrier(pimpl->comm); } -RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) { +RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) +{ return RegisteredMemory(std::make_shared(ptr, size, pimpl->comm->rank, transports, *pimpl)); } -MSCCLPP_API_CPP std::shared_ptr 
Communicator::connect(int remoteRank, int tag, Transport transport) { +MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, Transport transport) +{ std::shared_ptr conn; if (transport == Transport::CudaIpc) { // sanity check: make sure the IPC connection is being made within a node if (pimpl->rankToHash_[remoteRank] != pimpl->rankToHash_[pimpl->bootstrap_->getRank()]) { std::stringstream ss; - ss << "Cuda IPC connection can only be made within a node: " << remoteRank << "(" << std::hex << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")" << " != " - << pimpl->bootstrap_->getRank() << "(" << std::hex << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")"; + ss << "Cuda IPC connection can only be made within a node: " << remoteRank << "(" << std::hex + << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")" + << " != " << pimpl->bootstrap_->getRank() << "(" << std::hex + << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")"; throw std::runtime_error(ss.str()); - } + } auto cudaIpcConn = std::make_shared(); conn = cudaIpcConn; - INFO(MSCCLPP_P2P, "Cuda IPC connection between rank %d(%lx) and remoteRank %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], - remoteRank, pimpl->rankToHash_[remoteRank]); + INFO(MSCCLPP_P2P, "Cuda IPC connection between rank %d(%lx) and remoteRank %d(%lx) created", + pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], remoteRank, + pimpl->rankToHash_[remoteRank]); } else if (AllIBTransports.has(transport)) { auto ibConn = std::make_shared(remoteRank, tag, transport, *pimpl); conn = ibConn; - INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], - getIBDeviceName(transport).c_str(), remoteRank, pimpl->rankToHash_[remoteRank]); + INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank 
%d(%lx) created", + pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], + getIBDeviceName(transport).c_str(), remoteRank, pimpl->rankToHash_[remoteRank]); } else { throw std::runtime_error("Unsupported transport"); } @@ -80,7 +94,8 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank return conn; } -MSCCLPP_API_CPP void Communicator::connectionSetup() { +MSCCLPP_API_CPP void Communicator::connectionSetup() +{ for (auto& conn : pimpl->connections) { conn->startSetup(pimpl->bootstrap_); } diff --git a/src/connection.cc b/src/connection.cc index 031f63ec5..75a6ba797 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -1,12 +1,13 @@ #include "connection.hpp" #include "checks.hpp" -#include "registered_memory.hpp" -#include "npkit/npkit.h" #include "infiniband/verbs.h" +#include "npkit/npkit.h" +#include "registered_memory.hpp" namespace mscclpp { -void validateTransport(RegisteredMemory mem, Transport transport) { +void validateTransport(RegisteredMemory mem, Transport transport) +{ if (!mem.transports().has(transport)) { throw std::runtime_error("mem does not support transport"); } @@ -14,29 +15,36 @@ void validateTransport(RegisteredMemory mem, Transport transport) { // Connection -std::shared_ptr Connection::getRegisteredMemoryImpl(RegisteredMemory& mem) { +std::shared_ptr Connection::getRegisteredMemoryImpl(RegisteredMemory& mem) +{ return mem.pimpl; } // CudaIpcConnection -CudaIpcConnection::CudaIpcConnection() { +CudaIpcConnection::CudaIpcConnection() +{ cudaStreamCreate(&stream); } -CudaIpcConnection::~CudaIpcConnection() { +CudaIpcConnection::~CudaIpcConnection() +{ cudaStreamDestroy(stream); } -Transport CudaIpcConnection::transport() { +Transport CudaIpcConnection::transport() +{ return Transport::CudaIpc; } -Transport CudaIpcConnection::remoteTransport() { +Transport CudaIpcConnection::remoteTransport() +{ return Transport::CudaIpc; } -void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, 
RegisteredMemory src, uint64_t srcOffset, uint64_t size) { +void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) +{ validateTransport(dst, remoteTransport()); validateTransport(src, transport()); @@ -47,30 +55,38 @@ void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, Register // npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size); } -void CudaIpcConnection::flush() { +void CudaIpcConnection::flush() +{ CUDATHROW(cudaStreamSynchronize(stream)); // npkitCollectExitEvents(conn, NPKIT_EVENT_DMA_SEND_EXIT); } // IBConnection -IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl) : remoteRank_(remoteRank), tag_(tag), transport_(transport), remoteTransport_(Transport::Unknown) { +IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl) + : remoteRank_(remoteRank), tag_(tag), transport_(transport), remoteTransport_(Transport::Unknown) +{ qp = commImpl.getIbContext(transport)->createQp(); } -IBConnection::~IBConnection() { +IBConnection::~IBConnection() +{ // TODO: Destroy QP? 
} -Transport IBConnection::transport() { +Transport IBConnection::transport() +{ return transport_; } -Transport IBConnection::remoteTransport() { +Transport IBConnection::remoteTransport() +{ return remoteTransport_; } -void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { +void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) +{ validateTransport(dst, remoteTransport()); validateTransport(src, transport()); @@ -82,16 +98,18 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem if (!srcTransportInfo.ibLocal) { throw std::runtime_error("src is remote, which is not supported"); } - + auto dstMrInfo = dstTransportInfo.ibMrInfo; auto srcMr = srcTransportInfo.ibMr; - qp->stageSend(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/false); + qp->stageSend(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, + /*signaled=*/false); qp->postSend(); // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size); } -void IBConnection::flush() { +void IBConnection::flush() +{ bool isWaiting = true; while (isWaiting) { int wcNum = qp->pollCq(); @@ -114,11 +132,13 @@ void IBConnection::flush() { // npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT); } -void IBConnection::startSetup(std::shared_ptr bootstrap) { +void IBConnection::startSetup(std::shared_ptr bootstrap) +{ bootstrap->send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank_, tag_); } -void IBConnection::endSetup(std::shared_ptr bootstrap) { +void IBConnection::endSetup(std::shared_ptr bootstrap) +{ IbQpInfo qpInfo; bootstrap->recv(&qpInfo, sizeof(qpInfo), remoteRank_, tag_); qp->rtr(qpInfo); diff --git a/src/epoch.cc b/src/epoch.cc index 1fee307ea..f6c827311 100644 --- a/src/epoch.cc +++ b/src/epoch.cc @@ -3,20 +3,25 @@ 
namespace mscclpp { -struct Epoch::Impl { +struct Epoch::Impl +{ DeviceEpoch deviceEpoch; - Impl() { + Impl() + { MSCCLPPTHROW(mscclppCudaCalloc(&deviceEpoch.localSignalEpochId, 1)); MSCCLPPTHROW(mscclppCudaCalloc(&deviceEpoch.waitEpochId, 1)); } - ~Impl() { + ~Impl() + { MSCCLPPTHROW(mscclppCudaFree(deviceEpoch.localSignalEpochId)); MSCCLPPTHROW(mscclppCudaFree(deviceEpoch.waitEpochId)); } }; -Epoch::Epoch() : pimpl(std::make_unique()) {} +Epoch::Epoch() : pimpl(std::make_unique()) +{ +} } // namespace mscclpp \ No newline at end of file diff --git a/src/fifo.cc b/src/fifo.cc index fe7f12d3a..c2fdd7385 100644 --- a/src/fifo.cc +++ b/src/fifo.cc @@ -1,13 +1,14 @@ -#include "mscclppfifo.hpp" #include "alloc.h" #include "checks.hpp" +#include "mscclppfifo.hpp" #include -#include #include +#include namespace mscclpp { -struct HostProxyFifo::Impl { +struct HostProxyFifo::Impl +{ DeviceProxyFifo deviceFifo; // allocated on the host. Only accessed by the host. This is a copy of the @@ -23,7 +24,8 @@ struct HostProxyFifo::Impl { cudaStream_t stream; }; -HostProxyFifo::HostProxyFifo() { +HostProxyFifo::HostProxyFifo() +{ pimpl = std::make_unique(); MSCCLPPTHROW(mscclppCudaCalloc(&pimpl->deviceFifo.head, 1)); MSCCLPPTHROW(mscclppCudaHostCalloc(&pimpl->deviceFifo.triggers, MSCCLPP_PROXY_FIFO_SIZE)); @@ -32,35 +34,40 @@ HostProxyFifo::HostProxyFifo() { pimpl->hostTail = 0; } -HostProxyFifo::~HostProxyFifo() { +HostProxyFifo::~HostProxyFifo() +{ MSCCLPPTHROW(mscclppCudaFree(pimpl->deviceFifo.head)); MSCCLPPTHROW(mscclppCudaHostFree(pimpl->deviceFifo.triggers)); MSCCLPPTHROW(mscclppCudaFree(pimpl->deviceFifo.tailReplica)); CUDATHROW(cudaStreamDestroy(pimpl->stream)); } -void HostProxyFifo::poll(ProxyTrigger *trigger) { +void HostProxyFifo::poll(ProxyTrigger* trigger) +{ __m128i xmm0 = _mm_load_si128((__m128i*)&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]); _mm_store_si128((__m128i*)trigger, xmm0); } -void HostProxyFifo::pop() { +void 
HostProxyFifo::pop() +{ *(volatile uint64_t*)(&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]) = 0; (pimpl->hostTail)++; } -void HostProxyFifo::flushTail(bool sync) { +void HostProxyFifo::flushTail(bool sync) +{ // Flush the tail to device memory. This is either triggered every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER to make sure // that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush // request. - CUDATHROW( - cudaMemcpyAsync(pimpl->deviceFifo.tailReplica, &pimpl->hostTail, sizeof(uint64_t), cudaMemcpyHostToDevice, pimpl->stream)); + CUDATHROW(cudaMemcpyAsync(pimpl->deviceFifo.tailReplica, &pimpl->hostTail, sizeof(uint64_t), cudaMemcpyHostToDevice, + pimpl->stream)); if (sync) { CUDATHROW(cudaStreamSynchronize(pimpl->stream)); } } -DeviceProxyFifo HostProxyFifo::toDevice() { +DeviceProxyFifo HostProxyFifo::toDevice() +{ return pimpl->deviceFifo; } diff --git a/src/host_connection.cc b/src/host_connection.cc index 72e11ffc5..e33069e25 100644 --- a/src/host_connection.cc +++ b/src/host_connection.cc @@ -1,52 +1,64 @@ #include "host_connection.hpp" -#include "communicator.hpp" +#include "api.h" #include "comm.h" +#include "communicator.hpp" #include "mscclpp.h" #include "mscclppfifo.h" -#include "api.h" namespace mscclpp { -HostConnection::Impl::Impl(Communicator* comm, mscclppConn* conn) : comm(comm), conn(conn) { +HostConnection::Impl::Impl(Communicator* comm, mscclppConn* conn) : comm(comm), conn(conn) +{ this->hostConn = conn->hostConn; } -HostConnection::Impl::~Impl() { +HostConnection::Impl::~Impl() +{ // TODO: figure out memory ownership. Does this deallocate the mscclppHostConn? Likely not. 
} MSCCLPP_API_CPP HostConnection::~HostConnection() = default; -MSCCLPP_API_CPP HostConnection::HostConnection(std::unique_ptr p) : pimpl(std::move(p)) {} +MSCCLPP_API_CPP HostConnection::HostConnection(std::unique_ptr p) : pimpl(std::move(p)) +{ +} -MSCCLPP_API_CPP int HostConnection::getId() { +MSCCLPP_API_CPP int HostConnection::getId() +{ return pimpl->conn->connId; } -MSCCLPP_API_CPP BufferHandle HostConnection::registerBuffer(void* data, uint64_t size) { +MSCCLPP_API_CPP BufferHandle HostConnection::registerBuffer(void* data, uint64_t size) +{ BufferHandle result; static_assert(sizeof(BufferHandle) == sizeof(mscclppBufferHandle_t)); - mscclppRegisterBufferForConnection(pimpl->comm->pimpl->comm, pimpl->conn->connId, data, size, reinterpret_cast(&result)); + mscclppRegisterBufferForConnection(pimpl->comm->pimpl->comm, pimpl->conn->connId, data, size, + reinterpret_cast(&result)); return result; } -MSCCLPP_API_CPP int HostConnection::numLocalBuffers() { +MSCCLPP_API_CPP int HostConnection::numLocalBuffers() +{ return pimpl->conn->bufferRegistrations.size() - 1; } -MSCCLPP_API_CPP BufferHandle HostConnection::getLocalBuffer(int index) { +MSCCLPP_API_CPP BufferHandle HostConnection::getLocalBuffer(int index) +{ return index + 1; } -MSCCLPP_API_CPP int HostConnection::numRemoteBuffers() { +MSCCLPP_API_CPP int HostConnection::numRemoteBuffers() +{ return pimpl->conn->remoteBufferRegistrations.size() - 1; } -MSCCLPP_API_CPP BufferHandle HostConnection::getRemoteBuffer(int index) { +MSCCLPP_API_CPP BufferHandle HostConnection::getRemoteBuffer(int index) +{ return index + 1; } -MSCCLPP_API_CPP ConnectionEpoch HostConnection::getEpoch() { +MSCCLPP_API_CPP ConnectionEpoch HostConnection::getEpoch() +{ ConnectionEpoch epoch; static_assert(sizeof(SignalEpochId) == sizeof(mscclppDevConnSignalEpochId)); epoch.localSignalEpochId = reinterpret_cast(pimpl->conn->devConn->localSignalEpochId); @@ -55,24 +67,29 @@ MSCCLPP_API_CPP ConnectionEpoch HostConnection::getEpoch() { return 
epoch; } - -MSCCLPP_API_CPP DeviceProxyFifo HostConnection::getDeviceFifo() { +MSCCLPP_API_CPP DeviceProxyFifo HostConnection::getDeviceFifo() +{ return pimpl->comm->pimpl->proxy.fifo().toDevice(); } -MSCCLPP_API_CPP void HostConnection::put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) { +MSCCLPP_API_CPP void HostConnection::put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, + uint64_t size) +{ pimpl->hostConn->put(dst, dstOffset, src, srcOffset, size); } -MSCCLPP_API_CPP void HostConnection::signal() { +MSCCLPP_API_CPP void HostConnection::signal() +{ pimpl->hostConn->signal(); } -MSCCLPP_API_CPP void HostConnection::flush() { +MSCCLPP_API_CPP void HostConnection::flush() +{ pimpl->hostConn->flush(); } -MSCCLPP_API_CPP void HostConnection::wait() { +MSCCLPP_API_CPP void HostConnection::wait() +{ pimpl->hostConn->wait(); } diff --git a/src/ib.cc b/src/ib.cc index 88d14d8ef..ec7e95f25 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -1,16 +1,16 @@ #include #include #include -#include #include +#include #include -#include "mscclpp.hpp" #include "alloc.h" +#include "checks.hpp" #include "comm.h" #include "debug.h" #include "ib.hpp" -#include "checks.hpp" +#include "mscclpp.hpp" #include #include @@ -28,7 +28,9 @@ IbMr::IbMr(void* pd, void* buff, std::size_t size) : buff(buff) uintptr_t addr = reinterpret_cast(buff) & -pageSize; std::size_t pages = (size + (reinterpret_cast(buff) - addr) + pageSize - 1) / pageSize; struct ibv_pd* _pd = reinterpret_cast(pd); - struct ibv_mr* _mr = ibv_reg_mr(_pd, reinterpret_cast(addr), pages * pageSize, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING); + struct ibv_mr* _mr = + ibv_reg_mr(_pd, reinterpret_cast(addr), pages * pageSize, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING); if (_mr == nullptr) { std::stringstream err; err << "ibv_reg_mr failed 
(errno " << errno << ")"; @@ -164,7 +166,9 @@ void IbQp::rtr(const IbQpInfo& info) qp_attr.ah_attr.sl = 0; qp_attr.ah_attr.src_path_bits = 0; qp_attr.ah_attr.port_num = info.port; - int ret = ibv_modify_qp(reinterpret_cast(this->qp), &qp_attr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); + int ret = ibv_modify_qp(reinterpret_cast(this->qp), &qp_attr, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | + IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); if (ret != 0) { std::stringstream err; err << "ibv_modify_qp failed (errno " << errno << ")"; @@ -182,7 +186,9 @@ void IbQp::rts() qp_attr.rnr_retry = 7; qp_attr.sq_psn = 0; qp_attr.max_rd_atomic = 1; - int ret = ibv_modify_qp(reinterpret_cast(this->qp), &qp_attr, IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC); + int ret = ibv_modify_qp(reinterpret_cast(this->qp), &qp_attr, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | + IBV_QP_MAX_QP_RD_ATOMIC); if (ret != 0) { std::stringstream err; err << "ibv_modify_qp failed (errno " << errno << ")"; @@ -190,7 +196,8 @@ void IbQp::rts() } } -int IbQp::stageSend(const IbMr *mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled) +int IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, + uint64_t dstOffset, bool signaled) { if (this->wrn >= MSCCLPP_IB_MAX_SENDS) { return -1; @@ -219,7 +226,8 @@ int IbQp::stageSend(const IbMr *mr, const IbMrInfo& info, uint32_t size, uint64_ return this->wrn; } -int IbQp::stageSendWithImm(const IbMr *mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData) +int IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, 
uint64_t srcOffset, + uint64_t dstOffset, bool signaled, unsigned int immData) { int wrn = this->stageSend(mr, info, size, wrId, srcOffset, dstOffset, signaled); struct ibv_send_wr* wrs_ = reinterpret_cast(this->wrs); @@ -234,7 +242,8 @@ void IbQp::postSend() return; } struct ibv_send_wr* bad_wr; - int ret = ibv_post_send(reinterpret_cast(this->qp), reinterpret_cast(this->wrs), &bad_wr); + int ret = ibv_post_send(reinterpret_cast(this->qp), reinterpret_cast(this->wrs), + &bad_wr); if (ret != 0) { std::stringstream err; err << "ibv_post_send failed (errno " << errno << ")"; @@ -260,7 +269,8 @@ void IbQp::postRecv(uint64_t wrId) int IbQp::pollCq() { - return ibv_poll_cq(reinterpret_cast(this->cq), MSCCLPP_IB_CQ_POLL_NUM, reinterpret_cast(this->wcs)); + return ibv_poll_cq(reinterpret_cast(this->cq), MSCCLPP_IB_CQ_POLL_NUM, + reinterpret_cast(this->wcs)); } IbQpInfo& IbQp::getInfo() @@ -317,8 +327,8 @@ bool IbCtx::isPortUsable(int port) const err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")"; throw std::runtime_error(err.str()); } - return portAttr.state == IBV_PORT_ACTIVE && (portAttr.link_layer == IBV_LINK_LAYER_ETHERNET || - portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND); + return portAttr.state == IBV_PORT_ACTIVE && + (portAttr.link_layer == IBV_LINK_LAYER_ETHERNET || portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND); } int IbCtx::getAnyActivePort() const @@ -362,43 +372,45 @@ const std::string& IbCtx::getDevName() const return this->devName; } -int getIBDeviceCount() { +int getIBDeviceCount() +{ int num; ibv_get_device_list(&num); return num; } -std::string getIBDeviceName(Transport ibTransport) { +std::string getIBDeviceName(Transport ibTransport) +{ int num; struct ibv_device** devices = ibv_get_device_list(&num); int ibTransportIndex; switch (ibTransport) { // TODO: get rid of this ugly switch - case Transport::IB0: - ibTransportIndex = 0; - break; - case Transport::IB1: - ibTransportIndex = 1; - break; - case Transport::IB2: - 
ibTransportIndex = 2; - break; - case Transport::IB3: - ibTransportIndex = 3; - break; - case Transport::IB4: - ibTransportIndex = 4; - break; - case Transport::IB5: - ibTransportIndex = 5; - break; - case Transport::IB6: - ibTransportIndex = 6; - break; - case Transport::IB7: - ibTransportIndex = 7; - break; - default: - throw std::runtime_error("Not an IB transport"); + case Transport::IB0: + ibTransportIndex = 0; + break; + case Transport::IB1: + ibTransportIndex = 1; + break; + case Transport::IB2: + ibTransportIndex = 2; + break; + case Transport::IB3: + ibTransportIndex = 3; + break; + case Transport::IB4: + ibTransportIndex = 4; + break; + case Transport::IB5: + ibTransportIndex = 5; + break; + case Transport::IB6: + ibTransportIndex = 6; + break; + case Transport::IB7: + ibTransportIndex = 7; + break; + default: + throw std::runtime_error("Not an IB transport"); } if (ibTransportIndex >= num) { throw std::runtime_error("IB transport out of range"); @@ -406,30 +418,31 @@ std::string getIBDeviceName(Transport ibTransport) { return devices[ibTransportIndex]->name; } -Transport getIBTransportByDeviceName(const std::string& ibDeviceName) { +Transport getIBTransportByDeviceName(const std::string& ibDeviceName) +{ int num; struct ibv_device** devices = ibv_get_device_list(&num); for (int i = 0; i < num; ++i) { if (ibDeviceName == devices[i]->name) { switch (i) { // TODO: get rid of this ugly switch - case 0: - return Transport::IB0; - case 1: - return Transport::IB1; - case 2: - return Transport::IB2; - case 3: - return Transport::IB3; - case 4: - return Transport::IB4; - case 5: - return Transport::IB5; - case 6: - return Transport::IB6; - case 7: - return Transport::IB7; - default: - throw std::runtime_error("IB device index out of range"); + case 0: + return Transport::IB0; + case 1: + return Transport::IB1; + case 2: + return Transport::IB2; + case 3: + return Transport::IB3; + case 4: + return Transport::IB4; + case 5: + return Transport::IB5; + case 6: + 
return Transport::IB6; + case 7: + return Transport::IB7; + default: + throw std::runtime_error("IB device index out of range"); } } } diff --git a/src/include/basic_proxy_handler.hpp b/src/include/basic_proxy_handler.hpp index 1c4b3f862..58e419309 100644 --- a/src/include/basic_proxy_handler.hpp +++ b/src/include/basic_proxy_handler.hpp @@ -1,12 +1,12 @@ #ifndef MSCCLPP_BASIC_PROXY_SERVICE_HPP_ #define MSCCLPP_BASIC_PROXY_SERVICE_HPP_ -#include "mscclpp.hpp" #include "communicator.hpp" +#include "mscclpp.hpp" namespace mscclpp { -ProxyHandler makeBasicProxyHandler(Communicator::Impl &comm); +ProxyHandler makeBasicProxyHandler(Communicator::Impl& comm); } diff --git a/src/include/channel.hpp b/src/include/channel.hpp index 10a5f6016..2303a57cb 100644 --- a/src/include/channel.hpp +++ b/src/include/channel.hpp @@ -1,8 +1,8 @@ #ifndef MSCCLPP_CHANNEL_HPP_ #define MSCCLPP_CHANNEL_HPP_ -#include "mscclpp.hpp" #include "epoch.hpp" +#include "mscclpp.hpp" #include "proxy.hpp" namespace mscclpp { @@ -18,7 +18,7 @@ const ChannelTriggerType channelTriggerFlag = 0x2; const ChannelTriggerType channelTriggerSync = 0x4; // This is just a numeric ID. 
Each HostConnection will have an internal array indexed by these handles -// mapping to the actual +// mapping to the actual using BufferHandle = uint32_t; #define MSCCLPP_BITS_SIZE 32 @@ -43,20 +43,32 @@ union ChannelTrigger { uint64_t dstBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE; uint64_t type : MSCCLPP_BITS_TYPE; uint64_t connId : MSCCLPP_BITS_CONNID; - uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_TYPE); // ensure 64-bit alignment + uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_BUFFER_HANDLE - + MSCCLPP_BITS_TYPE); // ensure 64-bit alignment } fields; #ifdef __CUDACC__ - __device__ ChannelTrigger() {} - __device__ ChannelTrigger(ProxyTrigger value) : value(value) {} - __device__ ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size, int connectionId) { + __device__ ChannelTrigger() + { + } + __device__ ChannelTrigger(ProxyTrigger value) : value(value) + { + } + __device__ ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, + uint64_t srcOffset, uint64_t size, int connectionId) + { value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size); - value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_BUFFER_HANDLE) + dst) << MSCCLPP_BITS_BUFFER_HANDLE) + src) << MSCCLPP_BITS_OFFSET) + dstOffset); + value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_BUFFER_HANDLE) + dst) + << MSCCLPP_BITS_BUFFER_HANDLE) + + src) + << MSCCLPP_BITS_OFFSET) + + dstOffset); } #endif // __CUDACC__ }; -struct ConnectionEpoch { +struct ConnectionEpoch +{ #ifdef __CUDACC__ __forceinline__ __device__ void wait() { @@ -81,8 +93,10 @@ struct ConnectionEpoch { uint64_t* waitEpochId; }; -class HostConnection { +class HostConnection +{ struct Impl; + public: /* HostConnection can not be constructed from user code and 
must instead be created through Communicator::connect */ HostConnection(std::unique_ptr); @@ -103,7 +117,7 @@ class HostConnection { * * Inputs: * index: the index of the handle to get - * + * * Returns: a handle to the buffer */ BufferHandle getLocalBuffer(int index); @@ -118,7 +132,7 @@ class HostConnection { * * Inputs: * index: the index of the handle to get - * + * * Returns: a handle to the buffer on the remote peer */ BufferHandle getRemoteBuffer(int index); @@ -140,19 +154,22 @@ class HostConnection { friend class Communicator; }; -struct DeviceConnection { +struct DeviceConnection +{ DeviceConnection() = default; DeviceConnection(HostConnection& hostConn) - : connectionId(hostConn.getId()), epoch(hostConn.getEpoch()), - fifo(hostConn.getDeviceFifo()) {} + : connectionId(hostConn.getId()), epoch(hostConn.getEpoch()), fifo(hostConn.getDeviceFifo()) + { + } DeviceConnection(const DeviceConnection& other) = default; DeviceConnection& operator=(DeviceConnection& other) = default; #ifdef __CUDACC__ - __forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) + __forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, + uint64_t size) { fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value); } @@ -168,10 +185,13 @@ struct DeviceConnection { fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1, connectionId).value); } - __forceinline__ __device__ void putWithSignal(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) + __forceinline__ __device__ void putWithSignal(BufferHandle dst, uint64_t dstOffset, BufferHandle src, + uint64_t srcOffset, uint64_t size) { epochIncrement(); - fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size, connectionId).value); + fifo.push( + ChannelTrigger(channelTriggerData | 
channelTriggerFlag, dst, dstOffset, src, srcOffset, size, connectionId) + .value); } __forceinline__ __device__ void putWithSignal(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) @@ -179,16 +199,20 @@ struct DeviceConnection { putWithSignal(dst, offset, src, offset, size); } - __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size) + __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, uint64_t dstOffset, BufferHandle src, + uint64_t srcOffset, uint64_t size) { epochIncrement(); - uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag | channelTriggerSync, dst, dstOffset, src, srcOffset, size, connectionId).value); + uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag | channelTriggerSync, dst, + dstOffset, src, srcOffset, size, connectionId) + .value); while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && *(volatile uint64_t*)fifo.tailReplica <= curFifoHead) ; } - __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) + __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, BufferHandle src, uint64_t offset, + uint64_t size) { putWithSignalAndFlush(dst, offset, src, offset, size); } @@ -223,10 +247,12 @@ struct DeviceConnection { DeviceProxyFifo fifo; }; -struct SimpleDeviceConnection { +struct SimpleDeviceConnection +{ SimpleDeviceConnection() = default; - SimpleDeviceConnection(HostConnection& hostConn) : devConn(hostConn) { + SimpleDeviceConnection(HostConnection& hostConn) : devConn(hostConn) + { dst = hostConn.getRemoteBuffer(0); src = hostConn.getLocalBuffer(0); } diff --git a/src/include/checks.hpp b/src/include/checks.hpp index ad985e769..69b222ee1 100644 --- a/src/include/checks.hpp +++ b/src/include/checks.hpp @@ -8,6 +8,7 @@ #define 
MSCCLPP_CHECKS_HPP_ #include "debug.h" +#include #include #define MSCCLPPTHROW(call) \ @@ -26,4 +27,14 @@ } \ } while (false) +#define CUTHROW(cmd) \ + do { \ + CUresult err = cmd; \ + if (err != CUDA_SUCCESS) { \ + const char* errStr; \ + cuGetErrorString(err, &errStr); \ + throw std::runtime_error(std::string("Cu failure '") + std::string(errStr) + "'"); \ + } \ + } while (false) + #endif diff --git a/src/include/comm.h b/src/include/comm.h index dce724fa6..e6a067d6f 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -9,14 +9,14 @@ #include "ib.hpp" #include "proxy.h" -#include #include +#include #define MAXCONNECTIONS 64 struct mscclppBufferRegistration { - void *data; + void* data; uint64_t size; }; diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index e8e274b92..25fface7f 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -1,19 +1,20 @@ #ifndef MSCCL_COMMUNICATOR_HPP_ #define MSCCL_COMMUNICATOR_HPP_ -#include "mscclpp.hpp" -#include "mscclpp.h" #include "channel.hpp" -#include "proxy.hpp" #include "ib.hpp" -#include +#include "mscclpp.h" +#include "mscclpp.hpp" +#include "proxy.hpp" #include +#include namespace mscclpp { class ConnectionBase; -struct Communicator::Impl { +struct Communicator::Impl +{ mscclppComm_t comm; std::vector> connections; std::unordered_map> ibContexts; diff --git a/src/include/connection.hpp b/src/include/connection.hpp index bd08802c1..f957c8a10 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -1,25 +1,27 @@ #ifndef MSCCLPP_CONNECTION_HPP_ #define MSCCLPP_CONNECTION_HPP_ +#include "communicator.hpp" +#include "ib.hpp" #include "mscclpp.hpp" #include -#include "ib.hpp" -#include "communicator.hpp" namespace mscclpp { // TODO: Add functionality to these classes for Communicator to do connectionSetup -class ConnectionBase : public Connection { +class ConnectionBase : public Connection +{ public: - virtual void startSetup(std::shared_ptr bootstrap) 
{}; - virtual void endSetup(std::shared_ptr bootstrap) {}; + virtual void startSetup(std::shared_ptr bootstrap){}; + virtual void endSetup(std::shared_ptr bootstrap){}; }; -class CudaIpcConnection : public ConnectionBase { +class CudaIpcConnection : public ConnectionBase +{ cudaStream_t stream; -public: +public: CudaIpcConnection(); ~CudaIpcConnection(); @@ -28,19 +30,21 @@ class CudaIpcConnection : public ConnectionBase { Transport remoteTransport() override; - void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; + void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) override; void flush() override; }; -class IBConnection : public ConnectionBase { +class IBConnection : public ConnectionBase +{ int remoteRank_; int tag_; Transport transport_; Transport remoteTransport_; IbQp* qp; -public: +public: IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl); ~IBConnection(); @@ -49,7 +53,8 @@ class IBConnection : public ConnectionBase { Transport remoteTransport() override; - void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; + void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) override; void flush() override; diff --git a/src/include/epoch.hpp b/src/include/epoch.hpp index 942edd8b4..fd25b51fe 100644 --- a/src/include/epoch.hpp +++ b/src/include/epoch.hpp @@ -5,7 +5,8 @@ namespace mscclpp { -struct alignas(16) SignalEpochId { +struct alignas(16) SignalEpochId +{ // every signal(), increaments this and either: // 1) proxy thread pushes it to the remote peer's localSignalEpochId->proxy // 2) gpu thread directly writes it to remoteSignalEpochId->device @@ -14,7 +15,8 @@ struct alignas(16) SignalEpochId { uint64_t proxy; }; -struct DeviceEpoch { +struct DeviceEpoch +{ #ifdef __CUDACC__ 
__forceinline__ __device__ void wait() { @@ -34,10 +36,11 @@ struct DeviceEpoch { uint64_t* waitEpochId; }; - -class Epoch { +class Epoch +{ struct Impl; std::unique_ptr pimpl; + public: Epoch(); ~Epoch(); diff --git a/src/include/host_connection.hpp b/src/include/host_connection.hpp index 495130d9a..8ac5d9f17 100644 --- a/src/include/host_connection.hpp +++ b/src/include/host_connection.hpp @@ -1,13 +1,14 @@ #ifndef MSCCLPP_HOST_CONNECTION_HPP_ #define MSCCLPP_HOST_CONNECTION_HPP_ -#include "mscclpp.hpp" -#include "mscclpp.h" #include "comm.h" +#include "mscclpp.h" +#include "mscclpp.hpp" namespace mscclpp { -struct HostConnection::Impl { +struct HostConnection::Impl +{ Communicator* comm; mscclppConn* conn; mscclppHostConn_t* hostConn; diff --git a/src/include/ib.hpp b/src/include/ib.hpp index b1baeb757..78d31ce6b 100644 --- a/src/include/ib.hpp +++ b/src/include/ib.hpp @@ -1,9 +1,9 @@ #ifndef MSCCLPP_IB_HPP_ #define MSCCLPP_IB_HPP_ +#include #include #include -#include #define MSCCLPP_IB_CQ_SIZE 1024 #define MSCCLPP_IB_CQ_POLL_NUM 1 @@ -55,8 +55,10 @@ class IbQp void rtr(const IbQpInfo& info); void rts(); - int stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled); - int stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData); + int stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, + uint64_t dstOffset, bool signaled); + int stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, + uint64_t dstOffset, bool signaled, unsigned int immData); void postSend(); void postRecv(uint64_t wrId); int pollCq(); diff --git a/src/include/mscclpp.h b/src/include/mscclpp.h index c01246abe..4789b80fe 100644 --- a/src/include/mscclpp.h +++ b/src/include/mscclpp.h @@ -191,7 +191,8 @@ struct 
mscclppHostConn { virtual ~mscclppHostConn() = default; virtual void put(uint64_t dstDataOffset, uint64_t srcDataOffset, uint64_t dataSize) = 0; - virtual void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize) = 0; + virtual void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, + uint64_t dataSize) = 0; virtual void signal() = 0; virtual void wait() = 0; virtual void flush() = 0; @@ -232,7 +233,6 @@ typedef enum mscclppNumResults = 8 } mscclppResult_t; - /* Create a unique ID for communication. Only needs to be called by one process. * Use with mscclppCommInitRankFromId(). * All processes need to provide the same ID to mscclppCommInitRankFromId(). @@ -343,7 +343,8 @@ mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, int tag, void * transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB) * ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P. */ -mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag, mscclppTransport_t transportType, const char* ibDev = 0); +mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag, + mscclppTransport_t transportType, const char* ibDev = 0); /* Register a buffer for use with a connection. * @@ -356,7 +357,8 @@ mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, * Outputs: * handle: a handle to the buffer registration */ -mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize, mscclppBufferHandle_t *handle); +mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize, + mscclppBufferHandle_t* handle); /* Establish all connections declared by mscclppConnect(). 
This function must be called after all mscclppConnect() * calls are made. This function ensures that all remote ranks are ready to communicate when it returns. diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index 3b9c6d8d5..8a85ebc68 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -6,16 +6,16 @@ #define MSCCLPP_PATCH 0 #define MSCCLPP_VERSION (MSCCLPP_MAJOR * 10000 + MSCCLPP_MINOR * 100 + MSCCLPP_PATCH) -#include +#include #include #include -#include - +#include namespace mscclpp { #define MSCCLPP_UNIQUE_ID_BYTES 128 -struct UniqueId { +struct UniqueId +{ char internal[MSCCLPP_UNIQUE_ID_BYTES]; }; @@ -64,7 +64,8 @@ class Bootstrap : public BaseBootstrap */ std::unique_ptr getUniqueId(); -enum class Transport { +enum class Transport +{ Unknown, CudaIpc, IB0, @@ -79,109 +80,137 @@ enum class Transport { }; namespace detail { - const size_t TransportFlagsSize = 10; - static_assert(TransportFlagsSize == static_cast(Transport::NumTransports), "TransportFlagsSize must match the number of transports"); - using TransportFlagsBase = std::bitset; -} +const size_t TransportFlagsSize = 10; +static_assert(TransportFlagsSize == static_cast(Transport::NumTransports), + "TransportFlagsSize must match the number of transports"); +using TransportFlagsBase = std::bitset; +} // namespace detail -class TransportFlags : private detail::TransportFlagsBase { +class TransportFlags : private detail::TransportFlagsBase +{ public: TransportFlags() = default; - TransportFlags(Transport transport) : detail::TransportFlagsBase(1 << static_cast(transport)) {} + TransportFlags(Transport transport) : detail::TransportFlagsBase(1 << static_cast(transport)) + { + } - bool has(Transport transport) const { + bool has(Transport transport) const + { return detail::TransportFlagsBase::test(static_cast(transport)); } - bool none() const { + bool none() const + { return detail::TransportFlagsBase::none(); } - bool any() const { + bool any() const + { return 
detail::TransportFlagsBase::any(); } - bool all() const { + bool all() const + { return detail::TransportFlagsBase::all(); } - size_t count() const { + size_t count() const + { return detail::TransportFlagsBase::count(); } - TransportFlags& operator|=(TransportFlags other) { + TransportFlags& operator|=(TransportFlags other) + { detail::TransportFlagsBase::operator|=(other); return *this; } - TransportFlags operator|(TransportFlags other) const { + TransportFlags operator|(TransportFlags other) const + { return TransportFlags(*this) |= other; } - TransportFlags operator|(Transport transport) const { + TransportFlags operator|(Transport transport) const + { return *this | TransportFlags(transport); } - TransportFlags& operator&=(TransportFlags other) { + TransportFlags& operator&=(TransportFlags other) + { detail::TransportFlagsBase::operator&=(other); return *this; } - TransportFlags operator&(TransportFlags other) const { + TransportFlags operator&(TransportFlags other) const + { return TransportFlags(*this) &= other; } - TransportFlags operator&(Transport transport) const { + TransportFlags operator&(Transport transport) const + { return *this & TransportFlags(transport); } - TransportFlags& operator^=(TransportFlags other) { + TransportFlags& operator^=(TransportFlags other) + { detail::TransportFlagsBase::operator^=(other); return *this; } - TransportFlags operator^(TransportFlags other) const { + TransportFlags operator^(TransportFlags other) const + { return TransportFlags(*this) ^= other; } - TransportFlags operator^(Transport transport) const { + TransportFlags operator^(Transport transport) const + { return *this ^ TransportFlags(transport); } - TransportFlags operator~() const { + TransportFlags operator~() const + { return TransportFlags(*this).flip(); } - bool operator==(TransportFlags other) const { + bool operator==(TransportFlags other) const + { return detail::TransportFlagsBase::operator==(other); } - bool operator!=(TransportFlags other) const { + 
bool operator!=(TransportFlags other) const + { return detail::TransportFlagsBase::operator!=(other); } - detail::TransportFlagsBase toBitset() const { + detail::TransportFlagsBase toBitset() const + { return *this; } private: - TransportFlags(detail::TransportFlagsBase bitset) : detail::TransportFlagsBase(bitset) {} + TransportFlags(detail::TransportFlagsBase bitset) : detail::TransportFlagsBase(bitset) + { + } }; -inline TransportFlags operator|(Transport transport1, Transport transport2) { +inline TransportFlags operator|(Transport transport1, Transport transport2) +{ return TransportFlags(transport1) | transport2; } -inline TransportFlags operator&(Transport transport1, Transport transport2) { +inline TransportFlags operator&(Transport transport1, Transport transport2) +{ return TransportFlags(transport1) & transport2; } -inline TransportFlags operator^(Transport transport1, Transport transport2) { +inline TransportFlags operator^(Transport transport1, Transport transport2) +{ return TransportFlags(transport1) ^ transport2; } const TransportFlags NoTransports = TransportFlags(); -const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transport::IB2 | Transport::IB3 | Transport::IB4 | Transport::IB5 | Transport::IB6 | Transport::IB7; +const TransportFlags AllIBTransports = Transport::IB0 | Transport::IB1 | Transport::IB2 | Transport::IB3 | + Transport::IB4 | Transport::IB5 | Transport::IB6 | Transport::IB7; const TransportFlags AllTransports = AllIBTransports | Transport::CudaIpc; int getIBDeviceCount(); @@ -191,11 +220,12 @@ Transport getIBTransportByDeviceName(const std::string& ibDeviceName); class Communicator; class Connection; -class RegisteredMemory { +class RegisteredMemory +{ struct Impl; std::shared_ptr pimpl; -public: +public: RegisteredMemory(std::shared_ptr pimpl); ~RegisteredMemory(); @@ -211,9 +241,11 @@ class RegisteredMemory { friend class Communicator; }; -class Connection { +class Connection +{ public: - virtual void 
write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) = 0; + virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) = 0; virtual void flush() = 0; @@ -225,24 +257,24 @@ class Connection { static std::shared_ptr getRegisteredMemoryImpl(RegisteredMemory&); }; -class Communicator { +class Communicator +{ public: /* Initialize the communicator. - * - * Inputs: - * bootstrap: an implementation of the of BaseBootstrap that the communicator will use - */ + * + * Inputs: + * bootstrap: an implementation of the of BaseBootstrap that the communicator will use + */ Communicator(std::shared_ptr bootstrap); - ~Communicator(); - + /* Ring-based AllGather through the bootstrap socket. - * - * Inputs: - * data: data array to be gathered where `[r*size, (r+1)*size)` is the data for rank `r` - * size: data size per rank - */ + * + * Inputs: + * data: data array to be gathered where `[r*size, (r+1)*size)` is the data for rank `r` + * size: data size per rank + */ void bootstrapAllGather(void* data, int size); /* A no-op function that is used to synchronize all processes via a bootstrap allgather*/ @@ -253,33 +285,34 @@ class Communicator { * Inputs: * data: base pointer to the memory * size: size of the memory region in bytes - * + * * Returns: a handle to the buffer */ RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports); /* Connect to a remote rank. This function only prepares metadata for connection. The actual connection - * is made by a following call of mscclppConnectionSetup(). Note that this function is two-way and a connection - * from rank i to remote rank j needs to have a counterpart from rank j to rank i. - * Note that with IB, buffers are registered at a page level and if a buffer is spread through multiple pages - * and do not fully utilize all of them, IB's QP has to register for all involved pages. 
This potentially has - * security risks if the devConn's accesses are given to a malicious process. - * - * Inputs: - * remoteRank: the rank of the remote process - * tag: the tag of the connection. tag is copied into the corresponding mscclppDevConn_t, which can be - * used to identify the connection inside a GPU kernel. - * transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB) - * ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P. - */ + * is made by a following call of mscclppConnectionSetup(). Note that this function is two-way and a connection + * from rank i to remote rank j needs to have a counterpart from rank j to rank i. + * Note that with IB, buffers are registered at a page level and if a buffer is spread through multiple pages + * and do not fully utilize all of them, IB's QP has to register for all involved pages. This potentially has + * security risks if the devConn's accesses are given to a malicious process. + * + * Inputs: + * remoteRank: the rank of the remote process + * tag: the tag of the connection. tag is copied into the corresponding mscclppDevConn_t, which can be + * used to identify the connection inside a GPU kernel. + * transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB) + * ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P. + */ std::shared_ptr connect(int remoteRank, int tag, Transport transport); /* Establish all connections declared by connect(). This function must be called after all connect() - * calls are made. This function ensures that all remote ranks are ready to communicate when it returns. - */ + * calls are made. This function ensures that all remote ranks are ready to communicate when it returns. 
+ */ void connectionSetup(); struct Impl; + private: std::unique_ptr pimpl; }; @@ -287,12 +320,13 @@ class Communicator { } // namespace mscclpp namespace std { - template <> - struct hash { - size_t operator()(const mscclpp::TransportFlags& flags) const { - return hash()(flags.toBitset()); - } - }; -} +template <> struct hash +{ + size_t operator()(const mscclpp::TransportFlags& flags) const + { + return hash()(flags.toBitset()); + } +}; +} // namespace std #endif // MSCCLPP_H_ diff --git a/src/include/mscclppfifo.hpp b/src/include/mscclppfifo.hpp index b5f8ba4c8..7e2820b00 100644 --- a/src/include/mscclppfifo.hpp +++ b/src/include/mscclppfifo.hpp @@ -1,13 +1,14 @@ #ifndef MSCCLPPFIFO_HPP_ #define MSCCLPPFIFO_HPP_ -#include #include #include +#include namespace mscclpp { -struct alignas(16) ProxyTrigger { +struct alignas(16) ProxyTrigger +{ uint64_t fst, snd; }; @@ -24,7 +25,8 @@ struct alignas(16) ProxyTrigger { * Why duplicating the tail is a good idea? The fifo is large engouh and we do not need frequent updates * for the tail as there is usually enough space for device threads to push their work into. */ -struct DeviceProxyFifo { +struct DeviceProxyFifo +{ #ifdef __CUDACC__ __forceinline__ __device__ uint64_t push(ProxyTrigger trigger) { @@ -34,29 +36,28 @@ struct DeviceProxyFifo { while (*(volatile uint64_t*)&this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0) ; ProxyTrigger* triggerPtr = (ProxyTrigger*)&(this->triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE]); - asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), - "l"(trigger.fst), "l"(trigger.snd)); + asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd)); return curFifoHead; } #endif // __CUDACC__ ProxyTrigger* triggers; // Allocate on host via cudaHostAlloc. This space is used for pushing the workelements - uint64_t* tailReplica; // Allocated on device. 
proxyState->fifoTailHost is the true tail on host and pused - // occasionally to device - uint64_t* head; // Allocated on device. Only accessed by device + uint64_t* tailReplica; // Allocated on device. proxyState->fifoTailHost is the true tail on host and pused + // occasionally to device + uint64_t* head; // Allocated on device. Only accessed by device }; class HostProxyFifo { public: HostProxyFifo(); - + ~HostProxyFifo(); - void poll(ProxyTrigger *trigger); - + void poll(ProxyTrigger* trigger); + void pop(); - + void flushTail(bool sync = false); DeviceProxyFifo toDevice(); diff --git a/src/include/proxy.h b/src/include/proxy.h index 3746806b7..5bcb7da54 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -60,7 +60,7 @@ struct mscclppProxyState int numaNodeToBind; mscclpp::IbCtx* ibContext; // For IB connection only - cudaStream_t p2pStream; // for P2P DMA engine only + cudaStream_t p2pStream; // for P2P DMA engine only struct mscclppProxyFifo fifo; }; diff --git a/src/include/proxy.hpp b/src/include/proxy.hpp index 70b6ba493..ac4116b31 100644 --- a/src/include/proxy.hpp +++ b/src/include/proxy.hpp @@ -3,12 +3,13 @@ #include -#include #include +#include namespace mscclpp { -enum class ProxyHandlerResult { +enum class ProxyHandlerResult +{ Continue, FlushFifoTailAndContinue, Stop, @@ -17,7 +18,8 @@ enum class ProxyHandlerResult { class Proxy; using ProxyHandler = std::function; -class Proxy { +class Proxy +{ public: Proxy(ProxyHandler handler); @@ -26,7 +28,7 @@ class Proxy { void start(); void stop(); - + HostProxyFifo& fifo(); private: diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index afe42da45..1c37ff04b 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -1,15 +1,16 @@ #ifndef MSCCLPP_REGISTERED_MEMORY_HPP_ #define MSCCLPP_REGISTERED_MEMORY_HPP_ -#include "mscclpp.hpp" -#include "mscclpp.h" -#include "ib.hpp" #include "communicator.hpp" +#include "ib.hpp" +#include 
"mscclpp.h" +#include "mscclpp.hpp" #include namespace mscclpp { -struct TransportInfo { +struct TransportInfo +{ Transport transport; // TODO: rewrite this using std::variant or something @@ -21,7 +22,8 @@ struct TransportInfo { }; }; -struct RegisteredMemory::Impl { +struct RegisteredMemory::Impl +{ void* data; size_t size; int rank; @@ -31,7 +33,8 @@ struct RegisteredMemory::Impl { Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl); Impl(const std::vector& data); - TransportInfo& getTransportInfo(Transport transport) { + TransportInfo& getTransportInfo(Transport transport) + { for (auto& entry : transportInfos) { if (entry.transport == transport) { return entry; diff --git a/src/include/registered_ptr.hpp b/src/include/registered_ptr.hpp index 7eadb6b0f..4f03ea40a 100644 --- a/src/include/registered_ptr.hpp +++ b/src/include/registered_ptr.hpp @@ -3,32 +3,44 @@ namespace mscclpp { -template -class RegisteredPtr { +template class RegisteredPtr +{ RegisteredMemory memory; size_t offset; + public: - RegisteredPtr(RegisteredMemory memory, size_t offset) : memory(memory), offset(offset) {} - RegisteredPtr(RegisteredMemory memory) : RegisteredPtr(memory, 0) {} - ~RegisteredPtr() {} + RegisteredPtr(RegisteredMemory memory, size_t offset) : memory(memory), offset(offset) + { + } + RegisteredPtr(RegisteredMemory memory) : RegisteredPtr(memory, 0) + { + } + ~RegisteredPtr() + { + } - RegisteredMemory memory() { + RegisteredMemory memory() + { return memory; } - T* data() { + T* data() + { return reinterpret_cast(memory.data()); } - size_t size() { + size_t size() + { return memory.size() / sizeof(T); } - size_t offset() { + size_t offset() + { return offset; } - RegisteredPtr operator+(size_t offset) { + RegisteredPtr operator+(size_t offset) + { return RegisteredPtr(memory, this->offset + offset); } diff --git a/src/init.cc b/src/init.cc index c5b6a66b6..03f037c48 100644 --- a/src/init.cc +++ b/src/init.cc @@ -6,8 +6,8 @@ #if 
defined(MSCCLPP_USE_GDRCOPY) #include "gdr.h" #endif -#include "mscclpp.h" #include "infiniband/verbs.h" +#include "mscclpp.h" #include #include #include @@ -327,7 +327,8 @@ struct mscclppHostP2PConn : mscclppHostConn { put(1, dstDataOffset, 1, srcDataOffset, dataSize); } - void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize) + void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, + uint64_t dataSize) { void* srcBuff = (void*)((char*)conn->bufferRegistrations[src].data + srcDataOffset); void* dstBuff = (void*)((char*)conn->remoteBufferRegistrations[dst].data + dstDataOffset); @@ -365,7 +366,8 @@ struct mscclppHostIBConn : mscclppHostConn { put(1, dstDataOffset, 1, srcDataOffset, dataSize); } - void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, uint64_t dataSize) + void put(mscclppBufferHandle_t dst, uint64_t dstDataOffset, mscclppBufferHandle_t src, uint64_t srcDataOffset, + uint64_t dataSize) { this->ibQp->stageSend(this->ibMrs[src], this->remoteIbMrInfos[dst], (uint32_t)dataSize, /*wrId=*/0, /*srcOffset=*/srcDataOffset, /*dstOffset=*/dstDataOffset, /*signaled=*/false); @@ -413,7 +415,8 @@ struct mscclppHostIBConn : mscclppHostConn std::vector remoteIbMrInfos; }; -MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag, mscclppTransport_t transportType, const char* ibDev) +MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int remoteRank, int tag, + mscclppTransport_t transportType, const char* ibDev) { // save this processes numa binding and set it to the one closest to the device // so that all the allocation are close to the device @@ -550,7 +553,8 @@ MSCCLPP_API mscclppResult_t mscclppConnectWithoutBuffer(mscclppComm_t comm, int MSCCLPPCHECK(setNumaState(curProcessState)); mscclppBufferHandle_t 
signalHandle = -1; - MSCCLPPCHECK(mscclppRegisterBufferForConnection(comm, connId, conn->devConn->localSignalEpochId, sizeof(mscclppDevConnSignalEpochId), &signalHandle)); + MSCCLPPCHECK(mscclppRegisterBufferForConnection(comm, connId, conn->devConn->localSignalEpochId, + sizeof(mscclppDevConnSignalEpochId), &signalHandle)); if (signalHandle != 0) { WARN("signal handle should be 0"); return mscclppInternalError; @@ -579,7 +583,9 @@ MSCCLPP_API mscclppResult_t mscclppConnect(mscclppComm_t comm, int remoteRank, i return mscclppSuccess; } -MSCCLPP_API mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, uint64_t buffSize, mscclppBufferHandle_t *handle) { +MSCCLPP_API mscclppResult_t mscclppRegisterBufferForConnection(mscclppComm_t comm, int connIdx, void* localBuff, + uint64_t buffSize, mscclppBufferHandle_t* handle) +{ if (connIdx >= comm->nConns) { WARN("connIdx out of range"); return mscclppInvalidArgument; @@ -605,26 +611,31 @@ struct connInfo mscclpp::IbQpInfo infoQp; std::vector bufferInfos; - struct header { + struct header + { mscclpp::IbQpInfo infoQp; int numBufferInfos; }; - mscclppResult_t sendOverBootstrap(void* bootstrap, int remoteRank, int tag) { + mscclppResult_t sendOverBootstrap(void* bootstrap, int remoteRank, int tag) + { header h; h.infoQp = infoQp; h.numBufferInfos = bufferInfos.size(); MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, &h, sizeof(header))); - MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo))); + MSCCLPPCHECK(bootstrapSend(bootstrap, remoteRank, tag, bufferInfos.data(), + bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo))); return mscclppSuccess; } - mscclppResult_t recvOverBootstrap(void* bootstrap, int remoteRank, int tag) { + mscclppResult_t recvOverBootstrap(void* bootstrap, int remoteRank, int tag) + { header h; MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, &h, 
sizeof(header))); infoQp = h.infoQp; bufferInfos.resize(h.numBufferInfos); - MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, bufferInfos.data(), bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo))); + MSCCLPPCHECK(bootstrapRecv(bootstrap, remoteRank, tag, bufferInfos.data(), + bufferInfos.size() * sizeof(mscclppBufferRegistrationInfo))); return mscclppSuccess; } }; @@ -637,7 +648,7 @@ mscclppResult_t mscclppP2pConnectionSetupStart(struct connInfo* connInfo /*input } // Add all registered buffers - for (const auto &bufReg : conn->bufferRegistrations) { + for (const auto& bufReg : conn->bufferRegistrations) { connInfo->bufferInfos.emplace_back(); CUDACHECK(cudaIpcGetMemHandle(&connInfo->bufferInfos.back().cudaHandle, bufReg.data)); connInfo->bufferInfos.back().size = bufReg.size; @@ -659,7 +670,8 @@ mscclppResult_t mscclppP2pConnectionSetupEnd(struct connInfo* connInfo /*input*/ // Open all remote registered buffers for (size_t i = 0; i < connInfo->bufferInfos.size(); i++) { mscclppBufferRegistration newBufReg; - CUDACHECK(cudaIpcOpenMemHandle(&newBufReg.data, connInfo->bufferInfos[i].cudaHandle, cudaIpcMemLazyEnablePeerAccess)); + CUDACHECK( + cudaIpcOpenMemHandle(&newBufReg.data, connInfo->bufferInfos[i].cudaHandle, cudaIpcMemLazyEnablePeerAccess)); newBufReg.size = connInfo->bufferInfos[i].size; conn->remoteBufferRegistrations.push_back(newBufReg); } @@ -670,8 +682,8 @@ mscclppResult_t mscclppP2pConnectionSetupEnd(struct connInfo* connInfo /*input*/ } conn->devConn->remoteSignalEpochId = (mscclppDevConnSignalEpochId*)conn->remoteBufferRegistrations[0].data; - // For backwards compatibility with the previous API that assumed one data buffer per connection, set the remote buffer - // to the first remote data buffer + // For backwards compatibility with the previous API that assumed one data buffer per connection, set the remote + // buffer to the first remote data buffer if (conn->remoteBufferRegistrations.size() > 1) { conn->devConn->remoteBuff = 
conn->remoteBufferRegistrations[1].data; } @@ -695,7 +707,7 @@ mscclppResult_t mscclppIbConnectionSetupStart(struct connInfo* connInfo /*output } // Add all registered buffers - for (const auto &bufReg : conn->bufferRegistrations) { + for (const auto& bufReg : conn->bufferRegistrations) { hostConn->ibMrs.emplace_back(ibCtx->registerMr(bufReg.data, sizeof(struct mscclppDevConnSignalEpochId))); connInfo->bufferInfos.emplace_back(); connInfo->bufferInfos.back().ibMrInfo = hostConn->ibMrs.back()->getInfo(); @@ -743,7 +755,8 @@ MSCCLPP_API mscclppResult_t mscclppConnectionSetup(mscclppComm_t comm) MSCCLPPCHECK(mscclppIbConnectionSetupStart(&cInfo, conn)); } // TODO: from saemal: do we possibly deadlock if there are too many outstanding sends? - // MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo, sizeof(cInfo))); + // MSCCLPPCHECK(bootstrapSend(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag, &cInfo, + // sizeof(cInfo))); MSCCLPPCHECK(cInfo.sendOverBootstrap(comm->bootstrap, conn->devConn->remoteRank, conn->devConn->tag)); } diff --git a/src/proxy_cpp.cc b/src/proxy_cpp.cc index 2d1cf0987..b55d6995f 100644 --- a/src/proxy_cpp.cc +++ b/src/proxy_cpp.cc @@ -1,8 +1,8 @@ +#include "api.h" #include "mscclpp.hpp" #include "utils.h" -#include "api.h" -#include #include +#include namespace mscclpp { @@ -10,26 +10,32 @@ const int ProxyStopCheckPeriod = 1000; const int ProxyFlushPeriod = 4; -struct Proxy::Impl { +struct Proxy::Impl +{ ProxyHandler handler; HostProxyFifo fifo; std::thread service; std::atomic_bool running; - Impl(ProxyHandler handler) : handler(handler), running(false) {} + Impl(ProxyHandler handler) : handler(handler), running(false) + { + } }; -MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) { +MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) +{ pimpl = std::make_unique(handler); } -MSCCLPP_API_CPP Proxy::~Proxy() { +MSCCLPP_API_CPP Proxy::~Proxy() +{ if (pimpl) { stop(); } } -MSCCLPP_API_CPP 
void Proxy::start() { +MSCCLPP_API_CPP void Proxy::start() +{ pimpl->running = true; pimpl->service = std::thread([this] { // from this point on, proxy thread will stay close to the device @@ -52,7 +58,7 @@ MSCCLPP_API_CPP void Proxy::start() { // Poll to see if we are ready to send anything fifo.poll(&trigger); if (trigger.fst == 0) { // TODO: this check is a potential pitfall for custom triggers - continue; // there is one in progress + continue; // there is one in progress } ProxyHandlerResult result = handler(trigger); @@ -83,14 +89,16 @@ MSCCLPP_API_CPP void Proxy::start() { }); } -MSCCLPP_API_CPP void Proxy::stop() { +MSCCLPP_API_CPP void Proxy::stop() +{ pimpl->running = false; if (pimpl->service.joinable()) { pimpl->service.join(); } } -MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo() { +MSCCLPP_API_CPP HostProxyFifo& Proxy::fifo() +{ return pimpl->fifo; } diff --git a/src/registered_memory.cc b/src/registered_memory.cc index b26ea2d54..b9769dc96 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -1,10 +1,13 @@ #include "registered_memory.hpp" #include "checks.hpp" #include +#include namespace mscclpp { -RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) : data(data), size(size), rank(rank), transports(transports) { +RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) + : data(data), size(size), rank(rank), transports(transports) +{ if (transports.has(Transport::CudaIpc)) { TransportInfo transportInfo; transportInfo.transport = Transport::CudaIpc; @@ -23,38 +26,53 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t transportInfo.ibLocal = true; this->transportInfos.push_back(transportInfo); }; - if (transports.has(Transport::IB0)) addIb(Transport::IB0); - if (transports.has(Transport::IB1)) addIb(Transport::IB1); - if (transports.has(Transport::IB2)) addIb(Transport::IB2); 
- if (transports.has(Transport::IB3)) addIb(Transport::IB3); - if (transports.has(Transport::IB4)) addIb(Transport::IB4); - if (transports.has(Transport::IB5)) addIb(Transport::IB5); - if (transports.has(Transport::IB6)) addIb(Transport::IB6); - if (transports.has(Transport::IB7)) addIb(Transport::IB7); + if (transports.has(Transport::IB0)) + addIb(Transport::IB0); + if (transports.has(Transport::IB1)) + addIb(Transport::IB1); + if (transports.has(Transport::IB2)) + addIb(Transport::IB2); + if (transports.has(Transport::IB3)) + addIb(Transport::IB3); + if (transports.has(Transport::IB4)) + addIb(Transport::IB4); + if (transports.has(Transport::IB5)) + addIb(Transport::IB5); + if (transports.has(Transport::IB6)) + addIb(Transport::IB6); + if (transports.has(Transport::IB7)) + addIb(Transport::IB7); } } -RegisteredMemory::RegisteredMemory(std::shared_ptr pimpl) : pimpl(pimpl) {} +RegisteredMemory::RegisteredMemory(std::shared_ptr pimpl) : pimpl(pimpl) +{ +} RegisteredMemory::~RegisteredMemory() = default; -void* RegisteredMemory::data() { +void* RegisteredMemory::data() +{ return pimpl->data; } -size_t RegisteredMemory::size() { +size_t RegisteredMemory::size() +{ return pimpl->size; } -int RegisteredMemory::rank() { +int RegisteredMemory::rank() +{ return pimpl->rank; } -TransportFlags RegisteredMemory::transports() { +TransportFlags RegisteredMemory::transports() +{ return pimpl->transports; } -std::vector RegisteredMemory::serialize() { +std::vector RegisteredMemory::serialize() +{ std::vector result; std::copy_n(reinterpret_cast(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl->rank), sizeof(pimpl->rank), std::back_inserter(result)); @@ -67,7 +85,8 @@ std::vector RegisteredMemory::serialize() { for (auto& entry : pimpl->transportInfos) { std::copy_n(reinterpret_cast(&entry.transport), sizeof(entry.transport), std::back_inserter(result)); if (entry.transport == Transport::CudaIpc) { - 
std::copy_n(reinterpret_cast(&entry.cudaIpcHandle), sizeof(entry.cudaIpcHandle), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&entry.cudaIpcHandle), sizeof(entry.cudaIpcHandle), + std::back_inserter(result)); } else if (AllIBTransports.has(entry.transport)) { std::copy_n(reinterpret_cast(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result)); } else { @@ -77,11 +96,13 @@ std::vector RegisteredMemory::serialize() { return result; } -RegisteredMemory RegisteredMemory::deserialize(const std::vector& data) { +RegisteredMemory RegisteredMemory::deserialize(const std::vector& data) +{ return RegisteredMemory(std::make_shared(data)); } -RegisteredMemory::Impl::Impl(const std::vector& serialization) { +RegisteredMemory::Impl::Impl(const std::vector& serialization) +{ auto it = serialization.begin(); std::copy_n(it, sizeof(this->size), reinterpret_cast(&this->size)); it += sizeof(this->size); @@ -118,6 +139,9 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { if (transports.has(Transport::CudaIpc)) { auto entry = getTransportInfo(Transport::CudaIpc); + void* baseDataPtr; + size_t baseDataSize; // dummy + CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data)); CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); } } diff --git a/tests/allgather_test_cpp.cu b/tests/allgather_test_cpp.cu index 9b056e846..908a24f4c 100644 --- a/tests/allgather_test_cpp.cu +++ b/tests/allgather_test_cpp.cu @@ -4,14 +4,14 @@ #ifdef MSCCLPP_USE_MPI_FOR_TESTS #include "mpi.h" #endif // MSCCLPP_USE_MPI_FOR_TESTS +#include +#include #include #include #include #include #include #include -#include -#include static int nranksPerNode = 8; @@ -50,7 +50,8 @@ static double getTime(void) __constant__ mscclpp::SimpleDeviceConnection constDevConns[16]; -__device__ void allgather0(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU) 
+__device__ void allgather0(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int remoteRank, + size_t nelemsPerGPU) { // this allgather is really simple and implemented as an alltoall @@ -69,8 +70,8 @@ __device__ void allgather0(mscclpp::SimpleDeviceConnection devConn, int rank, in devConn.wait(); } -__device__ void localAllGather(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank, - uint64_t offset, uint64_t size) +__device__ void localAllGather(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, + int remoteRank, uint64_t offset, uint64_t size) { // this allgather algorithm works as follows: // Step 1: GPU rank i sends data to GPU rank (i+1) % nranksPerNode @@ -93,15 +94,15 @@ __device__ void localAllGather(mscclpp::SimpleDeviceConnection devConn, int rank } } -__device__ void allgather1(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank, - size_t nelemsPerGPU) +__device__ void allgather1(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, + int remoteRank, size_t nelemsPerGPU) { localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int)); } -__device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank, - size_t nelemsPerGPU) +__device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, + int remoteRank, size_t nelemsPerGPU) { // this allgather is a pipelined and hierarchical one and only works for two nodes // it is implemented as follows: @@ -243,13 +244,13 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co comm.connectionSetup(); std::vector devConns; - std::transform(hostConns.begin(), hostConns.end(), std::back_inserter(devConns), - [](std::shared_ptr& 
hostConn) { - return mscclpp::SimpleDeviceConnection(*hostConn); - }); + std::transform( + hostConns.begin(), hostConns.end(), std::back_inserter(devConns), + [](std::shared_ptr& hostConn) { return mscclpp::SimpleDeviceConnection(*hostConn); }); assert(devConns.size() < sizeof(constDevConns) / sizeof(mscclpp::SimpleDeviceConnection)); - CUDACHECK(cudaMemcpyToSymbol(constDevConns, devConns.data(), sizeof(mscclpp::SimpleDeviceConnection) * devConns.size() )); + CUDACHECK( + cudaMemcpyToSymbol(constDevConns, devConns.data(), sizeof(mscclpp::SimpleDeviceConnection) * devConns.size())); } void printUsage(const char* prog, bool isMpi) @@ -399,17 +400,17 @@ int main(int argc, const char* argv[]) } size_t nelemsPerGPU = dataSize / sizeof(int) / world_size; - try{ + try { if (rank == 0) - printf("Initializing MSCCL++\n"); + printf("Initializing MSCCL++\n"); mscclpp::Communicator comm(world_size, ip_port, rank); if (rank == 0) - printf("Initializing data for allgather test\n"); + printf("Initializing data for allgather test\n"); initializeAndAllocateAllGatherData(rank, world_size, dataSize, nelemsPerGPU, &data_h, &data_d); if (rank == 0) - printf("Setting up the connection in MSCCL++\n"); + printf("Setting up the connection in MSCCL++\n"); setupMscclppConnections(rank, world_size, comm, data_d, dataSize); if (rank == 0) @@ -466,7 +467,7 @@ int main(int argc, const char* argv[]) int cudagraphwarmup = 10; if (rank == 0) printf("Warming up %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphwarmup, - cudagraphiter); + cudagraphiter); for (int i = 0; i < cudagraphwarmup; ++i) { cudaGraphLaunch(instance, stream); } @@ -476,7 +477,7 @@ int main(int argc, const char* argv[]) int cudagraphlaunch = 10; if (rank == 0) printf("Running %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphlaunch, - cudagraphiter); + cudagraphiter); comm.bootstrapAllGather(tmp, sizeof(int)); double t0, t1, ms, time_in_us; t0 = getTime(); @@ -489,7 +490,7 
@@ int main(int argc, const char* argv[]) ms = (t1 - t0) * 1000.0; time_in_us = ms * 1000. / (float)cudagraphlaunch / (float)cudagraphiter; printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us, - (double)(dataSize) / 1e9 / (time_in_us / 1e6)); + (double)(dataSize) / 1e9 / (time_in_us / 1e6)); comm.bootstrapAllGather(tmp, sizeof(int)); if (rank == 0) diff --git a/tests/bootstrap_test_cpp.cc b/tests/bootstrap_test_cpp.cc index bdde84673..e4fe65bb7 100644 --- a/tests/bootstrap_test_cpp.cc +++ b/tests/bootstrap_test_cpp.cc @@ -1,11 +1,12 @@ #include "mscclpp.hpp" -#include #include #include +#include #include -void test_allgather(std::shared_ptr bootstrap){ +void test_allgather(std::shared_ptr bootstrap) +{ std::vector tmp(bootstrap->getNranks(), 0); tmp[bootstrap->getRank()] = bootstrap->getRank() + 1; bootstrap->allGather(tmp.data(), sizeof(int)); @@ -16,13 +17,15 @@ void test_allgather(std::shared_ptr bootstrap){ std::cout << "AllGather test passed!" << std::endl; } -void test_barrier(std::shared_ptr bootstrap){ +void test_barrier(std::shared_ptr bootstrap) +{ bootstrap->barrier(); if (bootstrap->getRank() == 0) std::cout << "Barrier test passed!" << std::endl; } -void test_sendrecv(std::shared_ptr bootstrap){ +void test_sendrecv(std::shared_ptr bootstrap) +{ for (int i = 0; i < bootstrap->getNranks(); i++) { if (bootstrap->getRank() == i) continue; @@ -52,13 +55,15 @@ void test_sendrecv(std::shared_ptr bootstrap){ std::cout << "Send/Recv test passed!" 
<< std::endl; } -void test_all(std::shared_ptr bootstrap){ +void test_all(std::shared_ptr bootstrap) +{ test_allgather(bootstrap); test_barrier(bootstrap); test_sendrecv(bootstrap); } -void test_mscclpp_bootstrap_with_id(int rank, int worldSize){ +void test_mscclpp_bootstrap_with_id(int rank, int worldSize) +{ auto bootstrap = std::make_shared(rank, worldSize); mscclpp::UniqueId id; if (bootstrap->getRank() == 0) @@ -71,7 +76,8 @@ void test_mscclpp_bootstrap_with_id(int rank, int worldSize){ std::cout << "--- MSCCLPP::Bootstrap test with unique id passed! ---" << std::endl; } -void test_mscclpp_bootstrap_with_ip_port_pair(int rank, int worldSize, char* ipPortPiar){ +void test_mscclpp_bootstrap_with_ip_port_pair(int rank, int worldSize, char* ipPortPiar) +{ std::shared_ptr bootstrap(new mscclpp::Bootstrap(rank, worldSize)); bootstrap->initialize(ipPortPiar); @@ -80,47 +86,57 @@ void test_mscclpp_bootstrap_with_ip_port_pair(int rank, int worldSize, char* ipP std::cout << "--- MSCCLPP::Bootstrap test with ip_port pair passed! 
---" << std::endl; } -class MPIBootstrap : public mscclpp::BaseBootstrap { +class MPIBootstrap : public mscclpp::BaseBootstrap +{ public: - MPIBootstrap() : BaseBootstrap() {} - int getRank() override { + MPIBootstrap() : BaseBootstrap() + { + } + int getRank() override + { int rank; MPI_Comm_rank(MPI_COMM_WORLD, &rank); return rank; } - int getNranks() override { + int getNranks() override + { int worldSize; MPI_Comm_size(MPI_COMM_WORLD, &worldSize); return worldSize; } - void allGather(void *sendbuf, int size) override { + void allGather(void* sendbuf, int size) override + { MPI_Allgather(MPI_IN_PLACE, 0, MPI_BYTE, sendbuf, size, MPI_BYTE, MPI_COMM_WORLD); } - void barrier() override { + void barrier() override + { MPI_Barrier(MPI_COMM_WORLD); } - void send(void *sendbuf, int size, int dest, int tag) override { + void send(void* sendbuf, int size, int dest, int tag) override + { MPI_Send(sendbuf, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD); } - void recv(void *recvbuf, int size, int source, int tag) override { + void recv(void* recvbuf, int size, int source, int tag) override + { MPI_Recv(recvbuf, size, MPI_BYTE, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE); } }; -void test_mpi_bootstrap(){ +void test_mpi_bootstrap() +{ std::shared_ptr bootstrap(new MPIBootstrap()); test_all(bootstrap); if (bootstrap->getRank() == 0) std::cout << "--- MPI Bootstrap test passed! 
---" << std::endl; } -int main(int argc, char **argv) +int main(int argc, char** argv) { int rank, worldSize; MPI_Init(&argc, &argv); MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_size(MPI_COMM_WORLD, &worldSize); - if (argc > 2){ + if (argc > 2) { if (rank == 0) std::cout << "Usage: " << argv[0] << " [ip:port]" << std::endl; MPI_Finalize(); diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index 1f14ca792..6864d97b5 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -1,25 +1,20 @@ #include "mscclpp.hpp" -#include #include #include +#include #include -mscclpp::Transport findIb(int localRank){ - mscclpp::Transport IBs[] = { - mscclpp::Transport::IB0, - mscclpp::Transport::IB1, - mscclpp::Transport::IB2, - mscclpp::Transport::IB3, - mscclpp::Transport::IB4, - mscclpp::Transport::IB5, - mscclpp::Transport::IB6, - mscclpp::Transport::IB7 - }; +mscclpp::Transport findIb(int localRank) +{ + mscclpp::Transport IBs[] = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2, + mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5, + mscclpp::Transport::IB6, mscclpp::Transport::IB7}; return IBs[localRank]; } -void test_communicator(int rank, int worldSize, int nranksPerNode){ +void test_communicator(int rank, int worldSize, int nranksPerNode) +{ auto bootstrap = std::make_shared(rank, worldSize); mscclpp::UniqueId id; if (bootstrap->getRank() == 0) @@ -28,9 +23,9 @@ void test_communicator(int rank, int worldSize, int nranksPerNode){ bootstrap->initialize(id); auto communicator = std::make_shared(bootstrap); - for (int i = 0; i < worldSize; i++){ - if (i != rank){ - if (i / nranksPerNode == rank / nranksPerNode){ + for (int i = 0; i < worldSize; i++) { + if (i != rank) { + if (i / nranksPerNode == rank / nranksPerNode) { auto connect = communicator->connect(i, 0, mscclpp::Transport::CudaIpc); } else { auto connect = communicator->connect(i, 0, findIb(rank % nranksPerNode)); @@ 
-43,8 +38,7 @@ void test_communicator(int rank, int worldSize, int nranksPerNode){ std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl; } - -int main(int argc, char **argv) +int main(int argc, char** argv) { int rank, worldSize; MPI_Init(&argc, &argv); @@ -56,7 +50,7 @@ int main(int argc, char **argv) MPI_Comm_size(shmcomm, &shmWorldSize); int nranksPerNode = shmWorldSize; MPI_Comm_free(&shmcomm); - + test_communicator(rank, worldSize, nranksPerNode); MPI_Finalize(); diff --git a/tests/unittests/ib_test.cc b/tests/unittests/ib_test.cc index 6f84398f6..3d99acb2c 100644 --- a/tests/unittests/ib_test.cc +++ b/tests/unittests/ib_test.cc @@ -3,8 +3,8 @@ #include "ib.hpp" #include "infiniband/verbs.h" #include "mscclpp.hpp" -#include #include +#include // Measure current time in second. static double getTime(void) From 2ead25d8ebab548301e7c66d9283a1d55932750a Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Thu, 27 Apr 2023 21:36:13 +0000 Subject: [PATCH 22/54] INFO for IPC handle opened --- src/registered_memory.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/registered_memory.cc b/src/registered_memory.cc index b9769dc96..52cbb2901 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -142,7 +142,8 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) void* baseDataPtr; size_t baseDataSize; // dummy CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data)); - CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); + CUDATHROW(cudaIpcOpenMemHandle(&baseDataPtr, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); + INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", data); } } From cbfc21851d14185734dc2d4125e42ac660a7c184 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Thu, 27 Apr 2023 22:25:03 +0000 Subject: [PATCH 23/54] registered buffer test --- src/communicator.cc | 2 +- src/registered_memory.cc | 5 
+++-- tests/communicator_test_cpp.cc | 26 +++++++++++++++++++++++++- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index 78df252d0..79e45f8db 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -58,7 +58,7 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() mscclppBootstrapBarrier(pimpl->comm); } -RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) +MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) { return RegisteredMemory(std::make_shared(ptr, size, pimpl->comm->rank, transports, *pimpl)); } diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 52cbb2901..42a03a8e1 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -1,4 +1,5 @@ #include "registered_memory.hpp" +#include "api.h" #include "checks.hpp" #include #include @@ -45,11 +46,11 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t } } -RegisteredMemory::RegisteredMemory(std::shared_ptr pimpl) : pimpl(pimpl) +MSCCLPP_API_CPP RegisteredMemory::RegisteredMemory(std::shared_ptr pimpl) : pimpl(pimpl) { } -RegisteredMemory::~RegisteredMemory() = default; +MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default; void* RegisteredMemory::data() { diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index 6864d97b5..a05c89811 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -1,10 +1,19 @@ #include "mscclpp.hpp" #include +#include #include #include #include +#define CUDATHROW(cmd) \ + do { \ + cudaError_t err = cmd; \ + if (err != cudaSuccess) { \ + throw std::runtime_error(std::string("Cuda failure '") + cudaGetErrorString(err) + "'"); \ + } \ + } while (false) + mscclpp::Transport findIb(int localRank) { mscclpp::Transport IBs[] = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2, 
@@ -23,17 +32,32 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) bootstrap->initialize(id); auto communicator = std::make_shared(bootstrap); + if (bootstrap->getRank() == 0) + std::cout << "Communicator initialization passed" << std::endl; + + auto myIbDevice = findIb(rank % nranksPerNode); for (int i = 0; i < worldSize; i++) { if (i != rank) { if (i / nranksPerNode == rank / nranksPerNode) { auto connect = communicator->connect(i, 0, mscclpp::Transport::CudaIpc); } else { - auto connect = communicator->connect(i, 0, findIb(rank % nranksPerNode)); + auto connect = communicator->connect(i, 0, myIbDevice); } } } communicator->connectionSetup(); + if (bootstrap->getRank() == 0) + std::cout << "Connection setup passed" << std::endl; + + int* devicePtr; + int size = 1024; + CUDATHROW(cudaMalloc(&devicePtr, size)); + auto registeredMemory = communicator->registerMemory(devicePtr, size, mscclpp::Transport::CudaIpc | myIbDevice); + + if (bootstrap->getRank() == 0) + std::cout << "Memory registeration passed" << std::endl; + if (bootstrap->getRank() == 0) std::cout << "--- MSCCLPP::Communicator tests passed! 
---" << std::endl; } From 962e63b11abf207e41a2a2d57fcb4f2d330f054a Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Thu, 27 Apr 2023 23:57:51 +0000 Subject: [PATCH 24/54] deserializing registered memory is failing -- commented out --- src/registered_memory.cc | 16 ++++++++-------- tests/communicator_test_cpp.cc | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 42a03a8e1..516a4c64c 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -13,8 +13,11 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t TransportInfo transportInfo; transportInfo.transport = Transport::CudaIpc; cudaIpcMemHandle_t handle; - // TODO: translate data to a base pointer - CUDATHROW(cudaIpcGetMemHandle(&handle, data)); + + void* baseDataPtr; + size_t baseDataSize; // dummy + CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data)); + CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr)); transportInfo.cudaIpcHandle = handle; this->transportInfos.push_back(transportInfo); } @@ -72,7 +75,7 @@ TransportFlags RegisteredMemory::transports() return pimpl->transports; } -std::vector RegisteredMemory::serialize() +MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() { std::vector result; std::copy_n(reinterpret_cast(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result)); @@ -97,7 +100,7 @@ std::vector RegisteredMemory::serialize() return result; } -RegisteredMemory RegisteredMemory::deserialize(const std::vector& data) +MSCCLPP_API_CPP RegisteredMemory RegisteredMemory::deserialize(const std::vector& data) { return RegisteredMemory(std::make_shared(data)); } @@ -140,10 +143,7 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) if (transports.has(Transport::CudaIpc)) { auto entry = getTransportInfo(Transport::CudaIpc); - void* baseDataPtr; - size_t baseDataSize; // dummy - 
CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data)); - CUDATHROW(cudaIpcOpenMemHandle(&baseDataPtr, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); + CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", data); } } diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index a05c89811..7fccf57bd 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -55,6 +55,25 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) CUDATHROW(cudaMalloc(&devicePtr, size)); auto registeredMemory = communicator->registerMemory(devicePtr, size, mscclpp::Transport::CudaIpc | myIbDevice); + for (int i = 0; i < worldSize; i++) { + if (i != rank){ + auto serialized = registeredMemory.serialize(); + int serializedSize = serialized.size(); + bootstrap->send(&serializedSize, sizeof(int), i, 0); + bootstrap->send(serialized.data(), serializedSize, i, 1); + } + } + for (int i = 0; i < worldSize; i++) { + if (i != rank){ + int deserializedSize; + bootstrap->recv(&deserializedSize, sizeof(int), i, 0); + std::vector deserialized(deserializedSize); + bootstrap->recv(deserialized.data(), deserializedSize, i, 1); + // auto deserializedRegisteredMemory = mscclpp::RegisteredMemory::deserialize(deserialized); + } + } + + if (bootstrap->getRank() == 0) std::cout << "Memory registeration passed" << std::endl; From fa0fcb470e8e7910d3a1a2fedf33d4a4f1afdaee Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Fri, 28 Apr 2023 00:30:07 +0000 Subject: [PATCH 25/54] Lazy CUDA IPC handle opening --- src/include/registered_memory.hpp | 1 + src/registered_memory.cc | 20 ++++++++++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index 1c37ff04b..88c1005d5 100644 --- a/src/include/registered_memory.hpp +++ 
b/src/include/registered_memory.hpp @@ -25,6 +25,7 @@ struct TransportInfo struct RegisteredMemory::Impl { void* data; + bool dataInitialized; size_t size; int rank; TransportFlags transports; diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 516a4c64c..470e7c104 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -7,7 +7,7 @@ namespace mscclpp { RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) - : data(data), size(size), rank(rank), transports(transports) + : data(data), dataInitialized(true), size(size), rank(rank), transports(transports) { if (transports.has(Transport::CudaIpc)) { TransportInfo transportInfo; @@ -57,6 +57,18 @@ MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default; void* RegisteredMemory::data() { + if (!pimpl->dataInitialized) { + if (pimpl->transports.has(Transport::CudaIpc)) { + auto entry = pimpl->getTransportInfo(Transport::CudaIpc); + CUDATHROW(cudaIpcOpenMemHandle(&pimpl->data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); + INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", data); + } + else + { + pimpl->data = nullptr; + } + pimpl->dataInitialized = true; + } return pimpl->data; } @@ -141,11 +153,7 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) throw std::runtime_error("Deserialization failed"); } - if (transports.has(Transport::CudaIpc)) { - auto entry = getTransportInfo(Transport::CudaIpc); - CUDATHROW(cudaIpcOpenMemHandle(&data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); - INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", data); - } + dataInitialized = false; } } // namespace mscclpp From 821ba7a5281a4cbaf8102030d973d5d269155400 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Fri, 28 Apr 2023 00:30:36 +0000 Subject: [PATCH 26/54] Fix compilation --- src/registered_memory.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/registered_memory.cc b/src/registered_memory.cc index 470e7c104..3fae7a963 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -61,7 +61,7 @@ void* RegisteredMemory::data() if (pimpl->transports.has(Transport::CudaIpc)) { auto entry = pimpl->getTransportInfo(Transport::CudaIpc); CUDATHROW(cudaIpcOpenMemHandle(&pimpl->data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); - INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", data); + INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", pimpl->data); } else { From cbefe38fd40f4d9acbed6813e48465d8ca569be7 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 28 Apr 2023 09:12:21 +0000 Subject: [PATCH 27/54] aad conn write test --- src/communicator.cc | 1 + src/include/connection.hpp | 4 ++-- tests/communicator_test_cpp.cc | 43 ++++++++++++++++++++++++++++++---- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index 79e45f8db..359368626 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -21,6 +21,7 @@ Communicator::Impl::Impl(std::shared_ptr bootstrap) : bootstrap_( INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash); rankToHash_[bootstrap->getRank()] = hostHash; bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t)); + comm->rank = bootstrap->getRank(); } Communicator::Impl::~Impl() diff --git a/src/include/connection.hpp b/src/include/connection.hpp index f957c8a10..42ca6d47a 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -13,8 +13,8 @@ namespace mscclpp { class ConnectionBase : public Connection { public: - virtual void startSetup(std::shared_ptr bootstrap){}; - virtual void endSetup(std::shared_ptr bootstrap){}; + virtual void startSetup(std::shared_ptr){}; + virtual void endSetup(std::shared_ptr){}; }; class CudaIpcConnection : public ConnectionBase diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index 7fccf57bd..a0b12e431 100644 --- 
a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -35,14 +35,17 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) if (bootstrap->getRank() == 0) std::cout << "Communicator initialization passed" << std::endl; + std::vector> connections; auto myIbDevice = findIb(rank % nranksPerNode); for (int i = 0; i < worldSize; i++) { if (i != rank) { + std::shared_ptr conn; if (i / nranksPerNode == rank / nranksPerNode) { - auto connect = communicator->connect(i, 0, mscclpp::Transport::CudaIpc); + conn = communicator->connect(i, 0, mscclpp::Transport::CudaIpc); } else { - auto connect = communicator->connect(i, 0, myIbDevice); + conn = communicator->connect(i, 0, myIbDevice); } + connections.push_back(conn); } } communicator->connectionSetup(); @@ -63,20 +66,52 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) bootstrap->send(serialized.data(), serializedSize, i, 1); } } + std::vector registeredMemories; for (int i = 0; i < worldSize; i++) { if (i != rank){ int deserializedSize; bootstrap->recv(&deserializedSize, sizeof(int), i, 0); std::vector deserialized(deserializedSize); bootstrap->recv(deserialized.data(), deserializedSize, i, 1); - // auto deserializedRegisteredMemory = mscclpp::RegisteredMemory::deserialize(deserialized); + auto deserializedRegisteredMemory = mscclpp::RegisteredMemory::deserialize(deserialized); + registeredMemories.push_back(std::move(deserializedRegisteredMemory)); } } + if (bootstrap->getRank() == 0) + std::cout << "Memory registration passed" << std::endl; + + assert(size % worldSize == 0); + size_t writeSize = size / worldSize; + size_t dataCount = size / sizeof(int); + // std::vector hostBuffer(dataCount, 0); + std::shared_ptr hostBuffer(new int[dataCount]); + for (int i = 0; i < dataCount; i++) { + hostBuffer[i] = rank; + } + CUDATHROW(cudaMemcpy(devicePtr, hostBuffer.get(), size, cudaMemcpyHostToDevice)); + + for (int i = 0; i < worldSize; i++) { + if (i != rank) { + int 
peerRankIndex = i < rank ? i : i - 1; + auto conn = connections[peerRankIndex]; + conn->write(registeredMemories[peerRankIndex], rank * writeSize, registeredMemory, rank * writeSize, writeSize); + } + } + CUDATHROW(cudaDeviceSynchronize()); + MPI_Barrier(MPI_COMM_WORLD); + CUDATHROW(cudaMemcpy(hostBuffer.get(), devicePtr, size, cudaMemcpyDeviceToHost)); + size_t dataPerRank = writeSize / sizeof(int); + for (int i = 0; i < dataCount; i++) { + if (hostBuffer[i] != i / dataPerRank) { + throw std::runtime_error("Data mismatch, connection write failed"); + } + } if (bootstrap->getRank() == 0) - std::cout << "Memory registeration passed" << std::endl; + std::cout << "Connection write passed" << std::endl; + CUDATHROW(cudaFree(devicePtr)); if (bootstrap->getRank() == 0) std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl; } From 750c40b98719e9ae97b1c7c402020f47ee08f9a7 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 28 Apr 2023 10:48:56 +0000 Subject: [PATCH 28/54] Fix --- src/communicator.cc | 4 ++-- src/registered_memory.cc | 2 +- tests/communicator_test_cpp.cc | 20 ++++++++++++-------- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index 359368626..df213f8e9 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -21,7 +21,6 @@ Communicator::Impl::Impl(std::shared_ptr bootstrap) : bootstrap_( INFO(MSCCLPP_INIT, "Host hash: %lx", hostHash); rankToHash_[bootstrap->getRank()] = hostHash; bootstrap->allGather(rankToHash_.data(), sizeof(uint64_t)); - comm->rank = bootstrap->getRank(); } Communicator::Impl::~Impl() @@ -61,7 +60,8 @@ MSCCLPP_API_CPP void Communicator::bootstrapBarrier() MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) { - return RegisteredMemory(std::make_shared(ptr, size, pimpl->comm->rank, transports, *pimpl)); + return RegisteredMemory( + std::make_shared(ptr, size, pimpl->bootstrap_->getRank(), 
transports, *pimpl)); } MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, Transport transport) diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 3fae7a963..e298aee5f 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -77,7 +77,7 @@ size_t RegisteredMemory::size() return pimpl->size; } -int RegisteredMemory::rank() +MSCCLPP_API_CPP int RegisteredMemory::rank() { return pimpl->rank; } diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index a0b12e431..c1e812cd0 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -5,6 +5,7 @@ #include #include #include +#include #define CUDATHROW(cmd) \ do { \ @@ -35,7 +36,7 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) if (bootstrap->getRank() == 0) std::cout << "Communicator initialization passed" << std::endl; - std::vector> connections; + std::unordered_map> connections; auto myIbDevice = findIb(rank % nranksPerNode); for (int i = 0; i < worldSize; i++) { if (i != rank) { @@ -45,7 +46,7 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) } else { conn = communicator->connect(i, 0, myIbDevice); } - connections.push_back(conn); + connections[i] = conn; } } communicator->connectionSetup(); @@ -66,7 +67,7 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) bootstrap->send(serialized.data(), serializedSize, i, 1); } } - std::vector registeredMemories; + std::unordered_map registeredMemories; for (int i = 0; i < worldSize; i++) { if (i != rank){ int deserializedSize; @@ -74,14 +75,15 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) std::vector deserialized(deserializedSize); bootstrap->recv(deserialized.data(), deserializedSize, i, 1); auto deserializedRegisteredMemory = mscclpp::RegisteredMemory::deserialize(deserialized); - registeredMemories.push_back(std::move(deserializedRegisteredMemory)); + 
registeredMemories.insert({deserializedRegisteredMemory.rank(), deserializedRegisteredMemory}); } } + MPI_Barrier(MPI_COMM_WORLD); if (bootstrap->getRank() == 0) std::cout << "Memory registration passed" << std::endl; - assert(size % worldSize == 0); + assert((size / sizeof(int)) % worldSize == 0); size_t writeSize = size / worldSize; size_t dataCount = size / sizeof(int); // std::vector hostBuffer(dataCount, 0); @@ -91,11 +93,13 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) } CUDATHROW(cudaMemcpy(devicePtr, hostBuffer.get(), size, cudaMemcpyHostToDevice)); + MPI_Barrier(MPI_COMM_WORLD); for (int i = 0; i < worldSize; i++) { if (i != rank) { - int peerRankIndex = i < rank ? i : i - 1; - auto conn = connections[peerRankIndex]; - conn->write(registeredMemories[peerRankIndex], rank * writeSize, registeredMemory, rank * writeSize, writeSize); + auto& conn = connections.at(i); + auto& peerMemory = registeredMemories.at(i); + // printf("write to rank: %d, rank is %d\n", peerMemory.rank(), rank); + conn->write(peerMemory, rank * writeSize, registeredMemory, rank * writeSize, writeSize); } } CUDATHROW(cudaDeviceSynchronize()); From 04e878489df3136b9acb1b6283c2022361339b8d Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Fri, 28 Apr 2023 22:50:38 +0000 Subject: [PATCH 29/54] Work on a channel service --- Makefile | 2 +- src/basic_proxy_handler.cc | 31 ---- src/communicator.cc | 13 +- src/connection.cc | 16 +- src/epoch.cc | 32 ++-- src/host_connection.cc | 96 ------------ src/include/channel.hpp | 280 ++++++++++++++++------------------- src/include/communicator.hpp | 1 - src/include/connection.hpp | 11 +- src/include/epoch.hpp | 34 ++--- src/include/mscclpp.hpp | 32 ++-- src/include/mscclppfifo.hpp | 7 +- src/include/proxy.hpp | 2 +- tests/allgather_test_cpp.cu | 26 ++-- 14 files changed, 227 insertions(+), 356 deletions(-) delete mode 100644 src/basic_proxy_handler.cc delete mode 100644 src/host_connection.cc diff --git a/Makefile b/Makefile 
index 41896041f..782129c09 100644 --- a/Makefile +++ b/Makefile @@ -120,7 +120,7 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc) LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc) -LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc) +LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc epoch.cc) #LIBSRCS += $(addprefix src/,fifo.cc host_connection.cc proxy_cpp.cc basic_proxy_handler.cc) ifneq ($(NPKIT), 0) LIBSRCS += $(addprefix src/misc/,npkit.cc) diff --git a/src/basic_proxy_handler.cc b/src/basic_proxy_handler.cc deleted file mode 100644 index 424701315..000000000 --- a/src/basic_proxy_handler.cc +++ /dev/null @@ -1,31 +0,0 @@ -#include "basic_proxy_handler.hpp" - -namespace mscclpp { - -ProxyHandler makeBasicProxyHandler(Communicator::Impl& comm) -{ - return [&comm](ProxyTrigger triggerRaw) { - ChannelTrigger* trigger = reinterpret_cast(&triggerRaw); - HostConnection& conn = *comm.connections.at(trigger->fields.connId); - - auto result = ProxyHandlerResult::Continue; - - if (trigger->fields.type & mscclppData) { - conn.put(trigger->fields.dstBufferHandle, trigger->fields.dstOffset, trigger->fields.srcBufferHandle, - trigger->fields.srcOffset, trigger->fields.size); - } - - if (trigger->fields.type & mscclppFlag) { - conn.signal(); - } - - if (trigger->fields.type & mscclppSync) { - conn.flush(); - result = ProxyHandlerResult::FlushFifoTailAndContinue; - } - - return result; - }; -} - -} // namespace mscclpp diff --git a/src/communicator.cc b/src/communicator.cc index df213f8e9..21faeaee2 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -1,13 +1,11 @@ #include #include "api.h" -#include "basic_proxy_handler.hpp" #include "checks.hpp" #include "comm.h" #include "communicator.hpp" #include "connection.hpp" #include "debug.h" -#include "host_connection.hpp" #include "mscclpp.hpp" #include 
"registered_memory.hpp" #include "utils.h" @@ -48,14 +46,9 @@ MSCCLPP_API_CPP Communicator::Communicator(std::shared_ptr bootst { } -MSCCLPP_API_CPP void Communicator::bootstrapAllGather(void* data, int size) +MSCCLPP_API_CPP std::shared_ptr Communicator::bootstrapper() { - mscclppBootstrapAllGather(pimpl->comm, data, size); -} - -MSCCLPP_API_CPP void Communicator::bootstrapBarrier() -{ - mscclppBootstrapBarrier(pimpl->comm); + return pimpl->bootstrap_; } MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, TransportFlags transports) @@ -77,7 +70,7 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank << pimpl->rankToHash_[pimpl->bootstrap_->getRank()] << ")"; throw std::runtime_error(ss.str()); } - auto cudaIpcConn = std::make_shared(); + auto cudaIpcConn = std::make_shared(remoteRank, tag); conn = cudaIpcConn; INFO(MSCCLPP_P2P, "Cuda IPC connection between rank %d(%lx) and remoteRank %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], remoteRank, diff --git a/src/connection.cc b/src/connection.cc index 75a6ba797..4f8a45155 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -20,9 +20,17 @@ std::shared_ptr Connection::getRegisteredMemoryImpl(Regi return mem.pimpl; } +// ConnectionBase + +ConnectionBase::ConnectionBase(int remoteRank, int tag) : remoteRank_(remoteRank), tag_(tag) {} + +int ConnectionBase::remoteRank() { return remoteRank_; } + +int ConnectionBase::tag() { return tag_; } + // CudaIpcConnection -CudaIpcConnection::CudaIpcConnection() +CudaIpcConnection::CudaIpcConnection(int remoteRank, int tag) : ConnectionBase(remoteRank, tag) { cudaStreamCreate(&stream); } @@ -64,7 +72,7 @@ void CudaIpcConnection::flush() // IBConnection IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl) - : remoteRank_(remoteRank), tag_(tag), transport_(transport), remoteTransport_(Transport::Unknown) + : 
ConnectionBase(remoteRank, tag), transport_(transport), remoteTransport_(Transport::Unknown) { qp = commImpl.getIbContext(transport)->createQp(); } @@ -134,13 +142,13 @@ void IBConnection::flush() void IBConnection::startSetup(std::shared_ptr bootstrap) { - bootstrap->send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank_, tag_); + bootstrap->send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank(), tag()); } void IBConnection::endSetup(std::shared_ptr bootstrap) { IbQpInfo qpInfo; - bootstrap->recv(&qpInfo, sizeof(qpInfo), remoteRank_, tag_); + bootstrap->recv(&qpInfo, sizeof(qpInfo), remoteRank(), tag()); qp->rtr(qpInfo); qp->rts(); } diff --git a/src/epoch.cc b/src/epoch.cc index f6c827311..7bcab9c89 100644 --- a/src/epoch.cc +++ b/src/epoch.cc @@ -1,27 +1,27 @@ #include "epoch.hpp" #include "checks.hpp" +#include "alloc.h" namespace mscclpp { -struct Epoch::Impl -{ - DeviceEpoch deviceEpoch; +Epoch::Epoch(Communicator& communicator, std::shared_ptr connection) : connection_(connection) { + MSCCLPPTHROW(mscclppCudaCalloc(&device_.epochIds_, 1)); + MSCCLPPTHROW(mscclppCudaCalloc(&device_.expectedInboundEpochId_, 1)); - Impl() - { - MSCCLPPTHROW(mscclppCudaCalloc(&deviceEpoch.localSignalEpochId, 1)); - MSCCLPPTHROW(mscclppCudaCalloc(&deviceEpoch.waitEpochId, 1)); - } + localEpochIdsRegMem_ = communicator.registerMemory(device_.epochIds_, sizeof(device_.epochIds_), connection->transport()); + communicator.bootstrapper()->send(localEpochIdsRegMem_.serialize(), connection->remoteRank(), connection->tag()); + std::vector serializedRemoteEpochIds; + communicator.bootstrapper()->recv(serializedRemoteEpochIds, connection->remoteRank(), connection->tag()); + remoteEpochIdsRegMem_ = RegisteredMemory::deserialize(serializedRemoteEpochIds); +} - ~Impl() - { - MSCCLPPTHROW(mscclppCudaFree(deviceEpoch.localSignalEpochId)); - MSCCLPPTHROW(mscclppCudaFree(deviceEpoch.waitEpochId)); - } -}; +Epoch::~Epoch() { + MSCCLPPTHROW(mscclppCudaFree(&device_.epochIds_)); + 
MSCCLPPTHROW(mscclppCudaFree(&device_.expectedInboundEpochId_)); +} -Epoch::Epoch() : pimpl(std::make_unique()) -{ +void Epoch::signal() { + connection_->write(localEpochIdsRegMem_, offsetof(EpochIds, outbound_), remoteEpochIdsRegMem_, offsetof(EpochIds, inboundReplica_), sizeof(device_.epochIds_)); } } // namespace mscclpp \ No newline at end of file diff --git a/src/host_connection.cc b/src/host_connection.cc deleted file mode 100644 index e33069e25..000000000 --- a/src/host_connection.cc +++ /dev/null @@ -1,96 +0,0 @@ -#include "host_connection.hpp" -#include "api.h" -#include "comm.h" -#include "communicator.hpp" -#include "mscclpp.h" -#include "mscclppfifo.h" - -namespace mscclpp { - -HostConnection::Impl::Impl(Communicator* comm, mscclppConn* conn) : comm(comm), conn(conn) -{ - this->hostConn = conn->hostConn; -} - -HostConnection::Impl::~Impl() -{ - // TODO: figure out memory ownership. Does this deallocate the mscclppHostConn? Likely not. -} - -MSCCLPP_API_CPP HostConnection::~HostConnection() = default; - -MSCCLPP_API_CPP HostConnection::HostConnection(std::unique_ptr p) : pimpl(std::move(p)) -{ -} - -MSCCLPP_API_CPP int HostConnection::getId() -{ - return pimpl->conn->connId; -} - -MSCCLPP_API_CPP BufferHandle HostConnection::registerBuffer(void* data, uint64_t size) -{ - BufferHandle result; - static_assert(sizeof(BufferHandle) == sizeof(mscclppBufferHandle_t)); - mscclppRegisterBufferForConnection(pimpl->comm->pimpl->comm, pimpl->conn->connId, data, size, - reinterpret_cast(&result)); - return result; -} - -MSCCLPP_API_CPP int HostConnection::numLocalBuffers() -{ - return pimpl->conn->bufferRegistrations.size() - 1; -} - -MSCCLPP_API_CPP BufferHandle HostConnection::getLocalBuffer(int index) -{ - return index + 1; -} - -MSCCLPP_API_CPP int HostConnection::numRemoteBuffers() -{ - return pimpl->conn->remoteBufferRegistrations.size() - 1; -} - -MSCCLPP_API_CPP BufferHandle HostConnection::getRemoteBuffer(int index) -{ - return index + 1; -} - 
-MSCCLPP_API_CPP ConnectionEpoch HostConnection::getEpoch() -{ - ConnectionEpoch epoch; - static_assert(sizeof(SignalEpochId) == sizeof(mscclppDevConnSignalEpochId)); - epoch.localSignalEpochId = reinterpret_cast(pimpl->conn->devConn->localSignalEpochId); - epoch.remoteSignalEpochId = reinterpret_cast(pimpl->conn->devConn->remoteSignalEpochId); - epoch.waitEpochId = pimpl->conn->devConn->waitEpochId; - return epoch; -} - -MSCCLPP_API_CPP DeviceProxyFifo HostConnection::getDeviceFifo() -{ - return pimpl->comm->pimpl->proxy.fifo().toDevice(); -} - -MSCCLPP_API_CPP void HostConnection::put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, - uint64_t size) -{ - pimpl->hostConn->put(dst, dstOffset, src, srcOffset, size); -} - -MSCCLPP_API_CPP void HostConnection::signal() -{ - pimpl->hostConn->signal(); -} - -MSCCLPP_API_CPP void HostConnection::flush() -{ - pimpl->hostConn->flush(); -} - -MSCCLPP_API_CPP void HostConnection::wait() -{ - pimpl->hostConn->wait(); -} - -} // namespace mscclpp \ No newline at end of file diff --git a/src/include/channel.hpp b/src/include/channel.hpp index 2303a57cb..ace576614 100644 --- a/src/include/channel.hpp +++ b/src/include/channel.hpp @@ -4,26 +4,39 @@ #include "epoch.hpp" #include "mscclpp.hpp" #include "proxy.hpp" +#include "mscclppfifo.hpp" namespace mscclpp { +namespace channel { -// For every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER, a flush of the tail to device memory is triggered. -// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem. 
-#define MSCCLPP_PROXY_FIFO_SIZE 128 -#define MSCCLPP_PROXY_FIFO_FLUSH_COUNTER 4 +// A Channel pairs a Connection with an Epoch +class Channel +{ +public: + Channel(std::shared_ptr connection) : connection_(connection), epoch_(std::make_shared()) {}; + + Connection& connection() { return *connection_; } + Epoch& epoch() { return *epoch_; } -using ChannelTriggerType = uint64_t; -const ChannelTriggerType channelTriggerData = 0x1; -const ChannelTriggerType channelTriggerFlag = 0x2; -const ChannelTriggerType channelTriggerSync = 0x4; +private: + std::shared_ptr connection_; + std::shared_ptr epoch_; +}; + +using ChannelId = uint32_t; + +using TriggerType = uint64_t; +const TriggerType TriggerData = 0x1; +const TriggerType TriggerFlag = 0x2; +const TriggerType TriggerSync = 0x4; // This is just a numeric ID. Each HostConnection will have an internal array indexed by these handles // mapping to the actual -using BufferHandle = uint32_t; +using MemoryId = uint32_t; #define MSCCLPP_BITS_SIZE 32 #define MSCCLPP_BITS_OFFSET 32 -#define MSCCLPP_BITS_BUFFER_HANDLE 8 +#define MSCCLPP_BITS_REGMEM_HANDLE 8 #define MSCCLPP_BITS_TYPE 3 #define MSCCLPP_BITS_CONNID 10 @@ -39,11 +52,11 @@ union ChannelTrigger { uint64_t : (64 - MSCCLPP_BITS_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment // second 64 bits: value[1] uint64_t dstOffset : MSCCLPP_BITS_OFFSET; - uint64_t srcBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE; - uint64_t dstBufferHandle : MSCCLPP_BITS_BUFFER_HANDLE; + uint64_t srcMemoryId : MSCCLPP_BITS_REGMEM_HANDLE; + uint64_t dstMemoryId : MSCCLPP_BITS_REGMEM_HANDLE; uint64_t type : MSCCLPP_BITS_TYPE; - uint64_t connId : MSCCLPP_BITS_CONNID; - uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_BUFFER_HANDLE - MSCCLPP_BITS_BUFFER_HANDLE - + uint64_t chanId : MSCCLPP_BITS_CONNID; + uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_TYPE); // ensure 64-bit alignment } fields; @@ -54,12 +67,12 @@ union ChannelTrigger 
{ __device__ ChannelTrigger(ProxyTrigger value) : value(value) { } - __device__ ChannelTrigger(ChannelTriggerType type, BufferHandle dst, uint64_t dstOffset, BufferHandle src, + __device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, uint64_t size, int connectionId) { value.fst = ((srcOffset << MSCCLPP_BITS_SIZE) + size); - value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_BUFFER_HANDLE) + dst) - << MSCCLPP_BITS_BUFFER_HANDLE) + + value.snd = ((((((((connectionId << MSCCLPP_BITS_TYPE) + (uint64_t)type) << MSCCLPP_BITS_REGMEM_HANDLE) + dst) + << MSCCLPP_BITS_REGMEM_HANDLE) + src) << MSCCLPP_BITS_OFFSET) + dstOffset); @@ -67,114 +80,24 @@ union ChannelTrigger { #endif // __CUDACC__ }; -struct ConnectionEpoch -{ -#ifdef __CUDACC__ - __forceinline__ __device__ void wait() - { - (*waitEpochId) += 1; - while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId)) - ; - } - - __forceinline__ __device__ void epochIncrement() - { - *(volatile uint64_t*)&(localSignalEpochId->device) += 1; - } -#endif // __CUDACC__ - - SignalEpochId* localSignalEpochId; - // used by the signal() function directly from gpu - SignalEpochId* remoteSignalEpochId; - - // every wait(), increments this and then the gpu waits for either: - // 1) localSignalEpochId->proxy to be >= this in case of a proxy thread - // 2) remoteSignalEpochId->device to be >= this in case of a gpu thread - uint64_t* waitEpochId; -}; - -class HostConnection -{ - struct Impl; - -public: - /* HostConnection can not be constructed from user code and must instead be created through Communicator::connect */ - HostConnection(std::unique_ptr); - - ~HostConnection(); - - void write(); - - int getId(); - - /* Get the number of times registerBuffer(...) was called. - * - * Returns: the number of buffers registered - */ - int numLocalBuffers(); - - /* Get the BufferHandle returned by a call to registerBuffer(...) 
as identified by the index - * - * Inputs: - * index: the index of the handle to get - * - * Returns: a handle to the buffer - */ - BufferHandle getLocalBuffer(int index); - - /* Get the number of times registerBuffer(...) was called on the remote peer. - * - * Returns: the number of buffers registered on the remote peer - */ - int numRemoteBuffers(); - - /* Get the BufferHandle returned by a call to registerBuffer(...) on the remote peer as identified by the index - * - * Inputs: - * index: the index of the handle to get - * - * Returns: a handle to the buffer on the remote peer - */ - BufferHandle getRemoteBuffer(int index); - - ConnectionEpoch getEpoch(); - - DeviceProxyFifo getDeviceFifo(); - - void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, uint64_t size); - - void signal(); - - void flush(); - - void wait(); - -private: - std::unique_ptr pimpl; - friend class Communicator; -}; - -struct DeviceConnection +struct DeviceChannel { - DeviceConnection() = default; + DeviceChannel() = default; - DeviceConnection(HostConnection& hostConn) - : connectionId(hostConn.getId()), epoch(hostConn.getEpoch()), fifo(hostConn.getDeviceFifo()) - { - } + DeviceChannel(ChannelId channelId, DeviceEpoch epoch, DeviceProxyFifo fifo) : channelId_(channelId), epoch_(epoch), fifo_(fifo) {} - DeviceConnection(const DeviceConnection& other) = default; + DeviceChannel(const DeviceChannel& other) = default; - DeviceConnection& operator=(DeviceConnection& other) = default; + DeviceChannel& operator=(DeviceChannel& other) = default; #ifdef __CUDACC__ - __forceinline__ __device__ void put(BufferHandle dst, uint64_t dstOffset, BufferHandle src, uint64_t srcOffset, + __forceinline__ __device__ void put(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, uint64_t size) { - fifo.push(ChannelTrigger(channelTriggerData, dst, dstOffset, src, srcOffset, size, connectionId).value); + fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, 
srcOffset, size, channelId_).value); } - __forceinline__ __device__ void put(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) + __forceinline__ __device__ void put(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) { put(dst, offset, src, offset, size); } @@ -182,36 +105,36 @@ struct DeviceConnection __forceinline__ __device__ void signal() { epochIncrement(); - fifo.push(ChannelTrigger(channelTriggerFlag, 0, 0, 0, 0, 1, connectionId).value); + fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, channelId_).value); } - __forceinline__ __device__ void putWithSignal(BufferHandle dst, uint64_t dstOffset, BufferHandle src, + __forceinline__ __device__ void putWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, uint64_t size) { epochIncrement(); - fifo.push( - ChannelTrigger(channelTriggerData | channelTriggerFlag, dst, dstOffset, src, srcOffset, size, connectionId) + fifo_.push( + ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, channelId_) .value); } - __forceinline__ __device__ void putWithSignal(BufferHandle dst, BufferHandle src, uint64_t offset, uint64_t size) + __forceinline__ __device__ void putWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) { putWithSignal(dst, offset, src, offset, size); } - __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, uint64_t dstOffset, BufferHandle src, + __forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, uint64_t size) { epochIncrement(); - uint64_t curFifoHead = fifo.push(ChannelTrigger(channelTriggerData | channelTriggerFlag | channelTriggerSync, dst, - dstOffset, src, srcOffset, size, connectionId) + uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerData | TriggerFlag | TriggerSync, dst, + dstOffset, src, srcOffset, size, channelId_) .value); - while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] 
!= 0 && - *(volatile uint64_t*)fifo.tailReplica <= curFifoHead) + while (*(volatile uint64_t*)&fifo_.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && + *(volatile uint64_t*)fifo_.tailReplica <= curFifoHead) ; } - __forceinline__ __device__ void putWithSignalAndFlush(BufferHandle dst, BufferHandle src, uint64_t offset, + __forceinline__ __device__ void putWithSignalAndFlush(MemoryId dst, MemoryId src, uint64_t offset, uint64_t size) { putWithSignalAndFlush(dst, offset, src, offset, size); @@ -219,53 +142,103 @@ struct DeviceConnection __forceinline__ __device__ void flush() { - uint64_t curFifoHead = fifo.push(ChannelTrigger(mscclppSync, 0, 0, 0, 0, 1, connectionId).value); + uint64_t curFifoHead = fifo_.push(ChannelTrigger(mscclppSync, 0, 0, 0, 0, 1, channelId_).value); // we need to wait for two conditions to be met to ensure the CPU is done flushing. (1) wait for the tail // to go pass by curFifoHead (this is safety net) and (2) wait for the work element value to change to 0. - while (*(volatile uint64_t*)&fifo.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && - *(volatile uint64_t*)fifo.tailReplica <= curFifoHead) + while (*(volatile uint64_t*)&fifo_.triggers[curFifoHead % MSCCLPP_PROXY_FIFO_SIZE] != 0 && + *(volatile uint64_t*)fifo_.tailReplica <= curFifoHead) ; } __forceinline__ __device__ void wait() { - epoch.wait(); + epoch_.wait(); } __forceinline__ __device__ void epochIncrement() { - epoch.epochIncrement(); + epoch_.epochIncrement(); } #endif // __CUDACC__ - int connectionId; + ChannelId channelId_; - ConnectionEpoch epoch; + DeviceEpoch epoch_; // this is a concurrent fifo which is multiple threads from the device // can produce for and the sole proxy thread consumes it. 
- DeviceProxyFifo fifo; + DeviceProxyFifo fifo_; }; -struct SimpleDeviceConnection -{ - SimpleDeviceConnection() = default; +class DeviceChannelService; - SimpleDeviceConnection(HostConnection& hostConn) : devConn(hostConn) - { - dst = hostConn.getRemoteBuffer(0); - src = hostConn.getLocalBuffer(0); +inline ProxyHandler makeChannelProxyHandler(DeviceChannelService& channelService); + +class DeviceChannelService { +public: + DeviceChannelService() : proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }) {} + + ChannelId addChannel(std::shared_ptr connection) { + channels_.push_back(Channel(connection)); + return channels_.size() - 1; } - SimpleDeviceConnection(const SimpleDeviceConnection& other) = default; + MemoryId addMemory(RegisteredMemory memory) { + memories_.push_back(memory); + return memories_.size() - 1; + } + + Channel channel(ChannelId id) { return channels_[id]; } + DeviceChannel deviceChannel(ChannelId id) { return DeviceChannel(id, channels_[id].epoch().deviceEpoch(), proxy_.fifo().deviceFifo()); } + + void startProxy() { proxy_.start(); } + void stopProxy() { proxy_.stop(); } + +private: + std::vector channels_; + std::vector memories_; + Proxy proxy_; + + ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw) { + ChannelTrigger* trigger = reinterpret_cast(&triggerRaw); + Channel& channel = channels_[trigger->fields.chanId]; + + auto result = ProxyHandlerResult::Continue; + + if (trigger->fields.type & TriggerData) { + RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId]; + RegisteredMemory& src = memories_[trigger->fields.srcMemoryId]; + channel.connection().write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset, trigger->fields.size); + } + + if (trigger->fields.type & TriggerFlag) { + channel.epoch().signal(); + } + + if (trigger->fields.type & TriggerSync) { + channel.connection().flush(); + result = ProxyHandlerResult::FlushFifoTailAndContinue; + } + + return result; + } +}; + +struct 
SimpleDeviceChannel +{ + SimpleDeviceChannel() = default; + + SimpleDeviceChannel(DeviceChannel devChan, MemoryId dst, MemoryId src) : devChan_(devChan), dst_(dst), src_(src) {} + + SimpleDeviceChannel(const SimpleDeviceChannel& other) = default; - SimpleDeviceConnection& operator=(SimpleDeviceConnection& other) = default; + SimpleDeviceChannel& operator=(SimpleDeviceChannel& other) = default; #ifdef __CUDACC__ __forceinline__ __device__ void put(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) { - devConn.put(dst, dstOffset, src, srcOffset, size); + devChan_.put(dst_, dstOffset, src_, srcOffset, size); } __forceinline__ __device__ void put(uint64_t offset, uint64_t size) @@ -275,12 +248,12 @@ struct SimpleDeviceConnection __forceinline__ __device__ void signal() { - devConn.signal(); + devChan_.signal(); } __forceinline__ __device__ void putWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) { - devConn.putWithSignal(dst, dstOffset, src, srcOffset, size); + devChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size); } __forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) @@ -290,7 +263,7 @@ struct SimpleDeviceConnection __forceinline__ __device__ void putWithSignalAndFlush(uint64_t dstOffset, uint64_t srcOffset, uint64_t size) { - devConn.putWithSignalAndFlush(dst, dstOffset, src, srcOffset, size); + devChan_.putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size); } __forceinline__ __device__ void putWithSignalAndFlush(uint64_t offset, uint64_t size) @@ -300,26 +273,27 @@ struct SimpleDeviceConnection __forceinline__ __device__ void flush() { - devConn.flush(); + devChan_.flush(); } __forceinline__ __device__ void wait() { - devConn.wait(); + devChan_.wait(); } __forceinline__ __device__ void epochIncrement() { - devConn.epochIncrement(); + devChan_.epochIncrement(); } #endif // __CUDACC__ - DeviceConnection devConn; - BufferHandle dst; - BufferHandle src; + DeviceChannel devChan_; + MemoryId dst_; + MemoryId 
src_; }; +} // namespace channel } // namespace mscclpp #endif // MSCCLPP_CHANNEL_HPP_ diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 25fface7f..b9b28f896 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -1,7 +1,6 @@ #ifndef MSCCL_COMMUNICATOR_HPP_ #define MSCCL_COMMUNICATOR_HPP_ -#include "channel.hpp" #include "ib.hpp" #include "mscclpp.h" #include "mscclpp.hpp" diff --git a/src/include/connection.hpp b/src/include/connection.hpp index 42ca6d47a..b28b58908 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -12,7 +12,14 @@ namespace mscclpp { class ConnectionBase : public Connection { + int remoteRank_; + int tag_; public: + ConnectionBase(int remoteRank, int tag); + + int remoteRank() override; + int tag() override; + virtual void startSetup(std::shared_ptr){}; virtual void endSetup(std::shared_ptr){}; }; @@ -22,7 +29,7 @@ class CudaIpcConnection : public ConnectionBase cudaStream_t stream; public: - CudaIpcConnection(); + CudaIpcConnection(int remoteRank, int tag); ~CudaIpcConnection(); @@ -38,8 +45,6 @@ class CudaIpcConnection : public ConnectionBase class IBConnection : public ConnectionBase { - int remoteRank_; - int tag_; Transport transport_; Transport remoteTransport_; IbQp* qp; diff --git a/src/include/epoch.hpp b/src/include/epoch.hpp index fd25b51fe..2c6e3296d 100644 --- a/src/include/epoch.hpp +++ b/src/include/epoch.hpp @@ -5,14 +5,10 @@ namespace mscclpp { -struct alignas(16) SignalEpochId +struct alignas(16) EpochIds { - // every signal(), increaments this and either: - // 1) proxy thread pushes it to the remote peer's localSignalEpochId->proxy - // 2) gpu thread directly writes it to remoteSignalEpochId->device - uint64_t device; - // signal() function triggers the cpu proxy thread to write to it - uint64_t proxy; + uint64_t outbound_; + uint64_t inboundReplica_; }; struct DeviceEpoch @@ -20,34 +16,36 @@ struct DeviceEpoch #ifdef __CUDACC__ __forceinline__ 
__device__ void wait() { - (*waitEpochId) += 1; - while (*(volatile uint64_t*)&(localSignalEpochId->proxy) < (*waitEpochId)) - ; + (*expectedInboundEpochId_) += 1; + while (*(volatile uint64_t*)&(epochIds_->inboundReplica_) < (*expectedInboundEpochId_)); } __forceinline__ __device__ void epochIncrement() { - *(volatile uint64_t*)&(localSignalEpochId->device) += 1; + *(volatile uint64_t*)&(epochIds_->outbound_) += 1; } #endif // __CUDACC__ - SignalEpochId* localSignalEpochId; - SignalEpochId* remoteSignalEpochId; - uint64_t* waitEpochId; + EpochIds* epochIds_; + uint64_t* expectedInboundEpochId_; }; class Epoch { - struct Impl; - std::unique_ptr pimpl; + std::shared_ptr connection_; + DeviceEpoch device_; + RegisteredMemory localEpochIdsRegMem_; + RegisteredMemory remoteEpochIdsRegMem_; public: - Epoch(); + Epoch(Communicator& communicator, std::shared_ptr connection); ~Epoch(); void signal(); - DeviceEpoch& getDeviceEpoch(); + DeviceEpoch deviceEpoch() { + return device_; + } }; } // namespace mscclpp diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index 8a85ebc68..fde631801 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -30,6 +30,20 @@ class BaseBootstrap virtual void recv(void* data, int size, int peer, int tag) = 0; virtual void allGather(void* allData, int size) = 0; virtual void barrier() = 0; + + // TODO: move implementations of these helpers out of this header + void send(const std::vector& data, int peer, int tag) + { + send((void*)data.size(), sizeof(size_t), peer, tag); + send((void*)data.data(), data.size(), peer, tag); + } + void recv(std::vector& data, int peer, int tag) + { + size_t size; + recv((void*)&size, sizeof(size_t), peer, tag); + data.resize(size); + recv((void*)data.data(), data.size(), peer, tag); + } }; class Bootstrap : public BaseBootstrap @@ -223,9 +237,11 @@ class Connection; class RegisteredMemory { struct Impl; + // A shared_ptr is used since RegisteredMemory is functionally immutable, although 
internally some state is populated lazily. std::shared_ptr pimpl; public: + RegisteredMemory() = default; RegisteredMemory(std::shared_ptr pimpl); ~RegisteredMemory(); @@ -249,6 +265,10 @@ class Connection virtual void flush() = 0; + virtual int remoteRank() = 0; + + virtual int tag() = 0; + virtual Transport transport() = 0; virtual Transport remoteTransport() = 0; @@ -269,16 +289,8 @@ class Communicator ~Communicator(); - /* Ring-based AllGather through the bootstrap socket. - * - * Inputs: - * data: data array to be gathered where `[r*size, (r+1)*size)` is the data for rank `r` - * size: data size per rank - */ - void bootstrapAllGather(void* data, int size); - - /* A no-op function that is used to synchronize all processes via a bootstrap allgather*/ - void bootstrapBarrier(); + /* Return the bootstrapper held by this communicator. */ + std::shared_ptr bootstrapper(); /* Register a region of GPU memory for use in this communicator. * diff --git a/src/include/mscclppfifo.hpp b/src/include/mscclppfifo.hpp index 7e2820b00..c13e4fb8a 100644 --- a/src/include/mscclppfifo.hpp +++ b/src/include/mscclppfifo.hpp @@ -7,6 +7,11 @@ namespace mscclpp { +// For every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER, a flush of the tail to device memory is triggered. +// As long as MSCCLPP_PROXY_FIFO_SIZE is large enough, having a stale tail is not a problem. 
+#define MSCCLPP_PROXY_FIFO_SIZE 128 +#define MSCCLPP_PROXY_FIFO_FLUSH_COUNTER 4 + struct alignas(16) ProxyTrigger { uint64_t fst, snd; @@ -60,7 +65,7 @@ class HostProxyFifo void flushTail(bool sync = false); - DeviceProxyFifo toDevice(); + DeviceProxyFifo deviceFifo(); private: struct Impl; diff --git a/src/include/proxy.hpp b/src/include/proxy.hpp index ac4116b31..f913beac7 100644 --- a/src/include/proxy.hpp +++ b/src/include/proxy.hpp @@ -1,7 +1,7 @@ #ifndef MSCCLPP_PROXY_HPP_ #define MSCCLPP_PROXY_HPP_ -#include +#include "mscclppfifo.hpp" #include #include diff --git a/tests/allgather_test_cpp.cu b/tests/allgather_test_cpp.cu index 908a24f4c..8fb54733b 100644 --- a/tests/allgather_test_cpp.cu +++ b/tests/allgather_test_cpp.cu @@ -1,5 +1,6 @@ #include "mscclpp.h" #include "mscclpp.hpp" +#include "channel.hpp" #ifdef MSCCLPP_USE_MPI_FOR_TESTS #include "mpi.h" @@ -48,9 +49,9 @@ static double getTime(void) return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec; } -__constant__ mscclpp::SimpleDeviceConnection constDevConns[16]; +__constant__ mscclpp::channel::SimpleDeviceConnection constDevConns[16]; -__device__ void allgather0(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int remoteRank, +__device__ void allgather0(mscclpp::channel::SimpleDeviceConnection devConn, int rank, int world_size, int remoteRank, size_t nelemsPerGPU) { // this allgather is really simple and implemented as an alltoall @@ -70,7 +71,7 @@ __device__ void allgather0(mscclpp::SimpleDeviceConnection devConn, int rank, in devConn.wait(); } -__device__ void localAllGather(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, +__device__ void localAllGather(mscclpp::channel::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank, uint64_t offset, uint64_t size) { // this allgather algorithm works as follows: @@ -94,14 +95,14 @@ __device__ void localAllGather(mscclpp::SimpleDeviceConnection devConn, int rank } } 
-__device__ void allgather1(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, +__device__ void allgather1(mscclpp::channel::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank, size_t nelemsPerGPU) { localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int)); } -__device__ void allgather2(mscclpp::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, +__device__ void allgather2(mscclpp::channel::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, int remoteRank, size_t nelemsPerGPU) { // this allgather is a pipelined and hierarchical one and only works for two nodes @@ -170,7 +171,7 @@ __global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelem int warpId = threadIdx.x / 32; int remoteRank = (warpId < rank) ? warpId : warpId + 1; // Each warp is responsible for one of the remote ranks - mscclpp::SimpleDeviceConnection devConn = constDevConns[warpId]; + mscclpp::channel::SimpleDeviceConnection devConn = constDevConns[warpId]; if (kernel == 0) allgather0(devConn, rank, world_size, remoteRank, nelemsPerGPU); @@ -222,21 +223,24 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co int thisNode = rankToNode(rank); int cudaNum = rankToLocalRank(rank); std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum); - std::vector> hostConns; + mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr); + mscclpp::channel::DeviceChannelService channelService; for (int r = 0; r < world_size; ++r) { if (r == rank) continue; - mscclpp::TransportType transportType; + mscclpp::Transport transport; const char* ibDev = ibDevStr.c_str(); if (rankToNode(r) == thisNode) { ibDev = NULL; - transportType = mscclpp::TransportType::P2P; + transportType = mscclpp::Transport::CudaIpc; } else { - transportType = 
mscclpp::TransportType::IB; + transportType = ibTransport; } // Connect with all other ranks - auto hostConn = comm.connect(r, 0, transportType, ibDev); + auto conn = comm.connect(r, 0, transportType); + channelService.addChannel(conn); + // TODO: WIP hostConn->registerBuffer(data_d, dataSize); hostConns.push_back(hostConn); } From 7d1f038181cc5f31c4b4610a8733c5b127c7a290 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Sat, 29 Apr 2023 05:16:33 +0000 Subject: [PATCH 30/54] fixes for ib send/recv tests --- src/connection.cc | 4 +++- tests/communicator_test_cpp.cc | 30 ++++++++++++++++++++---------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/connection.cc b/src/connection.cc index 75a6ba797..439916ebe 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -103,7 +103,7 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem auto srcMr = srcTransportInfo.ibMr; qp->stageSend(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, - /*signaled=*/false); + /*signaled=*/true); qp->postSend(); // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size); } @@ -135,12 +135,14 @@ void IBConnection::flush() void IBConnection::startSetup(std::shared_ptr bootstrap) { bootstrap->send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank_, tag_); + bootstrap->send(&transport_, sizeof(transport_), remoteRank_, tag_); } void IBConnection::endSetup(std::shared_ptr bootstrap) { IbQpInfo qpInfo; bootstrap->recv(&qpInfo, sizeof(qpInfo), remoteRank_, tag_); + bootstrap->recv(&remoteTransport_, sizeof(remoteTransport_), remoteRank_, tag_); qp->rtr(qpInfo); qp->rts(); } diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index c1e812cd0..c4db0cf8e 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -79,7 +79,7 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) } } - MPI_Barrier(MPI_COMM_WORLD); + 
bootstrap->barrier(); if (bootstrap->getRank() == 0) std::cout << "Memory registration passed" << std::endl; @@ -93,24 +93,34 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) } CUDATHROW(cudaMemcpy(devicePtr, hostBuffer.get(), size, cudaMemcpyHostToDevice)); - MPI_Barrier(MPI_COMM_WORLD); + bootstrap->barrier(); for (int i = 0; i < worldSize; i++) { if (i != rank) { auto& conn = connections.at(i); auto& peerMemory = registeredMemories.at(i); // printf("write to rank: %d, rank is %d\n", peerMemory.rank(), rank); conn->write(peerMemory, rank * writeSize, registeredMemory, rank * writeSize, writeSize); + conn->flush(); } } - CUDATHROW(cudaDeviceSynchronize()); - MPI_Barrier(MPI_COMM_WORLD); - CUDATHROW(cudaMemcpy(hostBuffer.get(), devicePtr, size, cudaMemcpyDeviceToHost)); - size_t dataPerRank = writeSize / sizeof(int); - for (int i = 0; i < dataCount; i++) { - if (hostBuffer[i] != i / dataPerRank) { - throw std::runtime_error("Data mismatch, connection write failed"); + bootstrap->barrier(); + // polling until it becomes ready + bool ready = false; + int niter = 0; + do { + ready = true; + CUDATHROW(cudaMemcpy(hostBuffer.get(), devicePtr, size, cudaMemcpyDeviceToHost)); + size_t dataPerRank = writeSize / sizeof(int); + for (int i = 0; i < dataCount; i++) { + if (hostBuffer[i] != i / dataPerRank) { + ready = false; + } } - } + if (niter == 10000){ + throw std::runtime_error("Polling is stuck."); + } + niter++; + } while (!ready); if (bootstrap->getRank() == 0) std::cout << "Connection write passed" << std::endl; From 88426ad83a33e165894a1265bb59c4c121a1f5b3 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Mon, 1 May 2023 21:07:12 +0000 Subject: [PATCH 31/54] bug fix for ib memory registration --- src/connection.cc | 2 +- src/include/registered_memory.hpp | 6 ++++-- src/registered_memory.cc | 1 + 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/connection.cc b/src/connection.cc index 5d9f508af..5289ab596 100644 --- 
a/src/connection.cc +++ b/src/connection.cc @@ -102,7 +102,7 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem if (dstTransportInfo.ibLocal) { throw std::runtime_error("dst is local, which is not supported"); } - auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(remoteTransport()); + auto srcTransportInfo = getRegisteredMemoryImpl(src)->getTransportInfo(transport()); if (!srcTransportInfo.ibLocal) { throw std::runtime_error("src is remote, which is not supported"); } diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index 88c1005d5..e95507f1f 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -17,8 +17,10 @@ struct TransportInfo bool ibLocal; union { cudaIpcMemHandle_t cudaIpcHandle; - const IbMr* ibMr; - IbMrInfo ibMrInfo; + struct { + const IbMr* ibMr; + IbMrInfo ibMrInfo; + }; }; }; diff --git a/src/registered_memory.cc b/src/registered_memory.cc index e298aee5f..1215c0e20 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -28,6 +28,7 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t const IbMr* mr = commImpl.getIbContext(ibTransport)->registerMr(data, size); transportInfo.ibMr = mr; transportInfo.ibLocal = true; + transportInfo.ibMrInfo = mr->getInfo(); this->transportInfos.push_back(transportInfo); }; if (transports.has(Transport::IB0)) From 8a5a7873e05b150f659386cc86874141a5e73ab1 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Mon, 1 May 2023 21:40:18 +0000 Subject: [PATCH 32/54] test bug fix --- tests/communicator_test_cpp.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index c4db0cf8e..78bffaac4 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -92,6 +92,7 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) hostBuffer[i] = rank; } 
CUDATHROW(cudaMemcpy(devicePtr, hostBuffer.get(), size, cudaMemcpyHostToDevice)); + CUDATHROW(cudaDeviceSynchronize()); bootstrap->barrier(); for (int i = 0; i < worldSize; i++) { @@ -122,6 +123,7 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) niter++; } while (!ready); + bootstrap->barrier(); if (bootstrap->getRank() == 0) std::cout << "Connection write passed" << std::endl; From 5b7e76cae41f6d3eeb58a5eed4bbd80120efa4b6 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Mon, 1 May 2023 22:25:14 +0000 Subject: [PATCH 33/54] all tests are passing with memory registeration --- src/connection.cc | 19 +++++++++++++++---- src/registered_memory.cc | 1 + 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/connection.cc b/src/connection.cc index 5289ab596..2cfa72055 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -1,3 +1,4 @@ +#include #include "connection.hpp" #include "checks.hpp" #include "infiniband/verbs.h" @@ -142,15 +143,25 @@ void IBConnection::flush() void IBConnection::startSetup(std::shared_ptr bootstrap) { - bootstrap->send(&qp->getInfo(), sizeof(qp->getInfo()), remoteRank(), tag()); - bootstrap->send(&transport_, sizeof(transport_), remoteRank(), tag()); + std::vector ibQpTransport; + std::copy_n(reinterpret_cast(&qp->getInfo()), sizeof(qp->getInfo()), std::back_inserter(ibQpTransport)); + std::copy_n(reinterpret_cast(&transport_), sizeof(transport_), std::back_inserter(ibQpTransport)); + + bootstrap->send(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag()); } void IBConnection::endSetup(std::shared_ptr bootstrap) { + std::vector ibQpTransport(sizeof(IbQpInfo) + sizeof(Transport)); + bootstrap->recv(ibQpTransport.data(), ibQpTransport.size(), remoteRank(), tag()); + IbQpInfo qpInfo; - bootstrap->recv(&qpInfo, sizeof(qpInfo), remoteRank(), tag()); - bootstrap->recv(&remoteTransport_, sizeof(remoteTransport_), remoteRank(), tag()); + auto it = ibQpTransport.begin(); + std::copy_n(it, sizeof(qpInfo), 
reinterpret_cast(&qpInfo)); + it += sizeof(qpInfo); + std::copy_n(it, sizeof(remoteTransport_), reinterpret_cast(&remoteTransport_)); + it += sizeof(qpInfo); + qp->rtr(qpInfo); qp->rts(); } diff --git a/src/registered_memory.cc b/src/registered_memory.cc index 1215c0e20..abf17a8bc 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -30,6 +30,7 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t transportInfo.ibLocal = true; transportInfo.ibMrInfo = mr->getInfo(); this->transportInfos.push_back(transportInfo); + INFO(MSCCLPP_NET, "IB mr for address %p with size %ld is registered", data, size); }; if (transports.has(Transport::IB0)) addIb(Transport::IB0); From 961f5b38ddf1cfe5eebfef40c4d5b81defb6daa4 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Tue, 2 May 2023 00:44:13 +0000 Subject: [PATCH 34/54] more debbuging info + testing 1000 memory registerations --- src/connection.cc | 3 + tests/communicator_test_cpp.cc | 181 ++++++++++++++++++++------------- 2 files changed, 111 insertions(+), 73 deletions(-) diff --git a/src/connection.cc b/src/connection.cc index 2cfa72055..e0c524195 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -61,6 +61,8 @@ void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, Register char* srcPtr = (char*)src.data(); CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream)); + INFO(MSCCLPP_P2P, "CudaIpcConnection write: from %p to %p, size %lu", srcPtr + srcOffset, dstPtr + dstOffset, size); + // npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size); } @@ -114,6 +116,7 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem qp->stageSend(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/true); qp->postSend(); + INFO(MSCCLPP_NET, "IBConnection write: from %p to %p, size %lu", (uint8_t*)srcMr->getBuff() + 
srcOffset, (uint8_t*)dstMrInfo.addr + dstOffset, size); // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size); } diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index 78bffaac4..6f7aa3e16 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -23,6 +23,55 @@ mscclpp::Transport findIb(int localRank) return IBs[localRank]; } +void register_all_memories(std::unique_ptr& communicator, int rank, int worldSize, void* devicePtr, size_t deviceBufferSize, mscclpp::Transport myIbDevice, mscclpp::RegisteredMemory& localMemory, std::unordered_map& remoteMemory){ + localMemory = communicator->registerMemory(devicePtr, deviceBufferSize, mscclpp::Transport::CudaIpc | myIbDevice); + int serializedSize = 0; + for (int i = 0; i < worldSize; i++) { + if (i != rank){ + auto serialized = localMemory.serialize(); + serializedSize = serialized.size(); + communicator->bootstrapper()->send(serialized.data(), serializedSize, i, 0); + } + } + if (serializedSize == 0) { + throw std::runtime_error("Serialized size should have been set to a non-zero value."); + } + for (int i = 0; i < worldSize; i++) { + if (i != rank){ + std::vector deserialized(serializedSize); + communicator->bootstrapper()->recv(deserialized.data(), serializedSize, i, 0); + auto remote = mscclpp::RegisteredMemory::deserialize(deserialized); + remoteMemory[i] = remote; + } + } +} + +void make_connections(std::unique_ptr& communicator, int rank, int worldSize, int nRanksPerNode, mscclpp::Transport myIbDevice, std::unordered_map>& connections){ + for (int i = 0; i < worldSize; i++) { + if (i != rank){ + if (i / nRanksPerNode == rank / nRanksPerNode) { + connections[i] = communicator->connect(i, 0, mscclpp::Transport::CudaIpc); + } else { + connections[i] = communicator->connect(i, 0, myIbDevice); + } + } + } + communicator->connectionSetup(); +} + +void write_remote(int rank, int worldSize, std::unordered_map>& connections, std::unordered_map& 
remoteRegisteredMemories, mscclpp::RegisteredMemory& registeredMemory, int writeSize){ + for (int i = 0; i < worldSize; i++) { + if (i != rank) { + auto& conn = connections.at(i); + auto& peerMemory = remoteRegisteredMemories.at(i); + // printf("write to rank: %d, rank is %d\n", peerMemory.rank(), rank); + conn->write(peerMemory, rank * writeSize, registeredMemory, rank * writeSize, writeSize); + conn->flush(); + } + } + +} + void test_communicator(int rank, int worldSize, int nranksPerNode) { auto bootstrap = std::make_shared(rank, worldSize); @@ -32,104 +81,90 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); bootstrap->initialize(id); - auto communicator = std::make_shared(bootstrap); + auto communicator = std::make_unique(bootstrap); if (bootstrap->getRank() == 0) std::cout << "Communicator initialization passed" << std::endl; std::unordered_map> connections; auto myIbDevice = findIb(rank % nranksPerNode); - for (int i = 0; i < worldSize; i++) { - if (i != rank) { - std::shared_ptr conn; - if (i / nranksPerNode == rank / nranksPerNode) { - conn = communicator->connect(i, 0, mscclpp::Transport::CudaIpc); - } else { - conn = communicator->connect(i, 0, myIbDevice); - } - connections[i] = conn; - } - } - communicator->connectionSetup(); + make_connections(communicator, rank, worldSize, nranksPerNode, myIbDevice, connections); if (bootstrap->getRank() == 0) std::cout << "Connection setup passed" << std::endl; - int* devicePtr; - int size = 1024; - CUDATHROW(cudaMalloc(&devicePtr, size)); - auto registeredMemory = communicator->registerMemory(devicePtr, size, mscclpp::Transport::CudaIpc | myIbDevice); - - for (int i = 0; i < worldSize; i++) { - if (i != rank){ - auto serialized = registeredMemory.serialize(); - int serializedSize = serialized.size(); - bootstrap->send(&serializedSize, sizeof(int), i, 0); - bootstrap->send(serialized.data(), serializedSize, i, 1); - } + int numBuffers = 
1000; + std::vector devicePtr(numBuffers); + int deviceBufferSize = 1024*1024; + + std::vector localMemory(numBuffers); + std::vector> remoteMemory(numBuffers); + + for (int n = 0; n < numBuffers; n++) { + if (n % 100 == 0) + std::cout << "Registering memory for " << std::to_string(n) << " buffers" << std::endl; + CUDATHROW(cudaMalloc(&devicePtr[n], deviceBufferSize)); + register_all_memories(communicator, rank, worldSize, devicePtr[n], deviceBufferSize, myIbDevice, localMemory[n], remoteMemory[n]); } - std::unordered_map registeredMemories; - for (int i = 0; i < worldSize; i++) { - if (i != rank){ - int deserializedSize; - bootstrap->recv(&deserializedSize, sizeof(int), i, 0); - std::vector deserialized(deserializedSize); - bootstrap->recv(deserialized.data(), deserializedSize, i, 1); - auto deserializedRegisteredMemory = mscclpp::RegisteredMemory::deserialize(deserialized); - registeredMemories.insert({deserializedRegisteredMemory.rank(), deserializedRegisteredMemory}); - } - } - bootstrap->barrier(); if (bootstrap->getRank() == 0) - std::cout << "Memory registration passed" << std::endl; - - assert((size / sizeof(int)) % worldSize == 0); - size_t writeSize = size / worldSize; - size_t dataCount = size / sizeof(int); - // std::vector hostBuffer(dataCount, 0); - std::shared_ptr hostBuffer(new int[dataCount]); - for (int i = 0; i < dataCount; i++) { - hostBuffer[i] = rank; + std::cout << "Memory registration for " << std::to_string(numBuffers) << " buffers passed" << std::endl; + + + assert((deviceBufferSize / sizeof(int)) % worldSize == 0); + size_t writeSize = deviceBufferSize / worldSize; + size_t dataCount = deviceBufferSize / sizeof(int); + for (int n = 0; n < numBuffers; n++){ + std::vector hostBuffer(dataCount, 0); + for (int i = 0; i < dataCount; i++) { + hostBuffer[i] = rank + n * worldSize; + } + CUDATHROW(cudaMemcpy(devicePtr[n], hostBuffer.data(), deviceBufferSize, cudaMemcpyHostToDevice)); } - CUDATHROW(cudaMemcpy(devicePtr, hostBuffer.get(), size, 
cudaMemcpyHostToDevice)); CUDATHROW(cudaDeviceSynchronize()); bootstrap->barrier(); - for (int i = 0; i < worldSize; i++) { - if (i != rank) { - auto& conn = connections.at(i); - auto& peerMemory = registeredMemories.at(i); - // printf("write to rank: %d, rank is %d\n", peerMemory.rank(), rank); - conn->write(peerMemory, rank * writeSize, registeredMemory, rank * writeSize, writeSize); - conn->flush(); - } + if (bootstrap->getRank() == 0) + std::cout << "CUDA memory initialization passed" << std::endl; + + for (int n = 0; n < numBuffers; n++){ + write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], writeSize); } bootstrap->barrier(); - // polling until it becomes ready - bool ready = false; - int niter = 0; - do { - ready = true; - CUDATHROW(cudaMemcpy(hostBuffer.get(), devicePtr, size, cudaMemcpyDeviceToHost)); - size_t dataPerRank = writeSize / sizeof(int); - for (int i = 0; i < dataCount; i++) { - if (hostBuffer[i] != i / dataPerRank) { - ready = false; + if (bootstrap->getRank() == 0) + std::cout << "RDMA write for " << std::to_string(numBuffers) << " buffers passed" << std::endl; + + for (int n = 0; n < numBuffers; n++){ + // polling until it becomes ready + bool ready = false; + int niter = 0; + std::vector hostBuffer(dataCount, 0); + do { + ready = true; + CUDATHROW(cudaMemcpy(hostBuffer.data(), devicePtr[n], deviceBufferSize, cudaMemcpyDeviceToHost)); + for (int i = 0; i < worldSize; i++) { + for (int j = i*writeSize/sizeof(int); j < (i+1)*writeSize/sizeof(int); j++) { + if (hostBuffer[j] != i + n * worldSize) { + ready = false; + } + } } - } - if (niter == 10000){ - throw std::runtime_error("Polling is stuck."); - } - niter++; - } while (!ready); + if (niter == 10000){ + throw std::runtime_error("Polling is stuck."); + } + niter++; + } while (!ready); + } bootstrap->barrier(); if (bootstrap->getRank() == 0) - std::cout << "Connection write passed" << std::endl; + std::cout << "Polling for " << std::to_string(numBuffers) << " buffers 
passed" << std::endl; - CUDATHROW(cudaFree(devicePtr)); if (bootstrap->getRank() == 0) std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl; + + for (int n = 0; n < numBuffers; n++){ + CUDATHROW(cudaFree(devicePtr[n])); + } } int main(int argc, char** argv) From 6aa023ed1e205934a7f450a15b1a8d97a81a7e68 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Tue, 2 May 2023 03:28:09 +0000 Subject: [PATCH 35/54] moving serializer outside --- tests/communicator_test_cpp.cc | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index 6f7aa3e16..7c6423b46 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -25,17 +25,13 @@ mscclpp::Transport findIb(int localRank) void register_all_memories(std::unique_ptr& communicator, int rank, int worldSize, void* devicePtr, size_t deviceBufferSize, mscclpp::Transport myIbDevice, mscclpp::RegisteredMemory& localMemory, std::unordered_map& remoteMemory){ localMemory = communicator->registerMemory(devicePtr, deviceBufferSize, mscclpp::Transport::CudaIpc | myIbDevice); - int serializedSize = 0; + auto serialized = localMemory.serialize(); + int serializedSize = serialized.size(); for (int i = 0; i < worldSize; i++) { if (i != rank){ - auto serialized = localMemory.serialize(); - serializedSize = serialized.size(); communicator->bootstrapper()->send(serialized.data(), serializedSize, i, 0); } } - if (serializedSize == 0) { - throw std::runtime_error("Serialized size should have been set to a non-zero value."); - } for (int i = 0; i < worldSize; i++) { if (i != rank){ std::vector deserialized(serializedSize); From fe2b778abcb6a9f181a509033ad0ffb0115fb0c1 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Tue, 2 May 2023 03:50:57 +0000 Subject: [PATCH 36/54] flushing the full cq --- src/connection.cc | 12 +++++------- src/epoch.cc | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git 
a/src/connection.cc b/src/connection.cc index e0c524195..e1b64072d 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -126,18 +126,16 @@ void IBConnection::flush() while (isWaiting) { int wcNum = qp->pollCq(); if (wcNum < 0) { - WARN("pollCq failed: errno %d", errno); - continue; + throw std::runtime_error("pollCq failed: error no " + std::to_string(errno)); } + isWaiting = false; for (int i = 0; i < wcNum; ++i) { const struct ibv_wc* wc = reinterpret_cast(qp->getWc(i)); if (wc->status != IBV_WC_SUCCESS) { - WARN("wc status %d", wc->status); - continue; + throw std::runtime_error("pollCq failed: status " + std::to_string(wc->status)); } - if (wc->opcode == IBV_WC_RDMA_WRITE) { - isWaiting = false; - break; + if (wc->opcode != IBV_WC_RDMA_WRITE) { + isWaiting = true; } } } diff --git a/src/epoch.cc b/src/epoch.cc index 7bcab9c89..3d17c5a1c 100644 --- a/src/epoch.cc +++ b/src/epoch.cc @@ -21,7 +21,7 @@ Epoch::~Epoch() { } void Epoch::signal() { - connection_->write(localEpochIdsRegMem_, offsetof(EpochIds, outbound_), remoteEpochIdsRegMem_, offsetof(EpochIds, inboundReplica_), sizeof(device_.epochIds_)); + connection_->write(remoteEpochIdsRegMem_, offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_, offsetof(EpochIds, outbound_), sizeof(device_.epochIds_)); } } // namespace mscclpp \ No newline at end of file From 358c3d62b818fc8d146986a772879c33e6fd9bb8 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Tue, 2 May 2023 20:06:30 +0000 Subject: [PATCH 37/54] Generalize connectionSetup() into setup() --- src/communicator.cc | 73 ++++++++++++++++++++++++++++----- src/connection.cc | 7 +--- src/epoch.cc | 10 ++--- src/include/communicator.hpp | 5 ++- src/include/connection.hpp | 9 +--- src/include/epoch.hpp | 6 +-- src/include/host_connection.hpp | 23 ----------- src/include/mscclpp.hpp | 41 ++++++++++++++++-- tests/allgather_test_cpp.cu | 12 +++--- tests/communicator_test_cpp.cc | 2 +- 10 files changed, 117 insertions(+), 71 deletions(-) delete mode 100644 
src/include/host_connection.hpp diff --git a/src/communicator.cc b/src/communicator.cc index 21faeaee2..7af88c738 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -23,17 +23,17 @@ Communicator::Impl::Impl(std::shared_ptr bootstrap) : bootstrap_( Communicator::Impl::~Impl() { - ibContexts.clear(); + ibContexts_.clear(); } IbCtx* Communicator::Impl::getIbContext(Transport ibTransport) { // Find IB context or create it - auto it = ibContexts.find(ibTransport); - if (it == ibContexts.end()) { + auto it = ibContexts_.find(ibTransport); + if (it == ibContexts_.end()) { auto ibDev = getIBDeviceName(ibTransport); - ibContexts[ibTransport] = std::make_unique(ibDev); - return ibContexts[ibTransport].get(); + ibContexts_[ibTransport] = std::make_unique(ibDev); + return ibContexts_[ibTransport].get(); } else { return it->second.get(); } @@ -57,6 +57,50 @@ MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t std::make_shared(ptr, size, pimpl->bootstrap_->getRank(), transports, *pimpl)); } +struct MemorySender : public Setuppable +{ + MemorySender(RegisteredMemory memory, int remoteRank, int tag) + : memory_(memory), remoteRank_(remoteRank), tag_(tag) {} + + void beginSetup(std::shared_ptr bootstrap) override + { + bootstrap->send(memory_.serialize(), remoteRank_, tag_); + } + + RegisteredMemory memory_; + int remoteRank_; + int tag_; +}; + +void Communicator::sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag) +{ + addSetup(std::make_shared(memory, remoteRank, tag)); +} + +struct MemoryReceiver : public Setuppable +{ + MemoryReceiver(int remoteRank, int tag) + : remoteRank_(remoteRank), tag_(tag) {} + + void endSetup(std::shared_ptr bootstrap) override + { + std::vector data; + bootstrap->recv(data, remoteRank_, tag_); + memoryPromise_.set_value(RegisteredMemory::deserialize(data)); + } + + std::promise memoryPromise_; + int remoteRank_; + int tag_; +}; + +NonblockingFuture Communicator::recvMemoryOnSetup(int remoteRank, 
int tag) +{ + auto memoryReceiver = std::make_shared(remoteRank, tag); + addSetup(memoryReceiver); + return memoryReceiver->memoryPromise_.get_future(); +} + MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, Transport transport) { std::shared_ptr conn; @@ -84,18 +128,25 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank } else { throw std::runtime_error("Unsupported transport"); } - pimpl->connections.push_back(conn); + pimpl->connections_.push_back(conn); + addSetup(conn); return conn; } -MSCCLPP_API_CPP void Communicator::connectionSetup() +MSCCLPP_API_CPP void Communicator::addSetup(std::shared_ptr setuppable) +{ + pimpl->toSetup_.push_back(setuppable); +} + +MSCCLPP_API_CPP void Communicator::setup() { - for (auto& conn : pimpl->connections) { - conn->startSetup(pimpl->bootstrap_); + for (auto& setuppable : pimpl->toSetup_) { + setuppable->beginSetup(pimpl->bootstrap_); } - for (auto& conn : pimpl->connections) { - conn->endSetup(pimpl->bootstrap_); + for (auto& setuppable : pimpl->toSetup_) { + setuppable->endSetup(pimpl->bootstrap_); } + pimpl->toSetup_.clear(); } } // namespace mscclpp diff --git a/src/connection.cc b/src/connection.cc index e1b64072d..f1ab06f8b 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -80,11 +80,6 @@ IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communi qp = commImpl.getIbContext(transport)->createQp(); } -IBConnection::~IBConnection() -{ - // TODO: Destroy QP? 
-} - Transport IBConnection::transport() { return transport_; @@ -142,7 +137,7 @@ void IBConnection::flush() // npkitCollectExitEvents(conn, NPKIT_EVENT_IB_SEND_EXIT); } -void IBConnection::startSetup(std::shared_ptr bootstrap) +void IBConnection::beginSetup(std::shared_ptr bootstrap) { std::vector ibQpTransport; std::copy_n(reinterpret_cast(&qp->getInfo()), sizeof(qp->getInfo()), std::back_inserter(ibQpTransport)); diff --git a/src/epoch.cc b/src/epoch.cc index 3d17c5a1c..a14191fd6 100644 --- a/src/epoch.cc +++ b/src/epoch.cc @@ -9,10 +9,8 @@ Epoch::Epoch(Communicator& communicator, std::shared_ptr connection) MSCCLPPTHROW(mscclppCudaCalloc(&device_.expectedInboundEpochId_, 1)); localEpochIdsRegMem_ = communicator.registerMemory(device_.epochIds_, sizeof(device_.epochIds_), connection->transport()); - communicator.bootstrapper()->send(localEpochIdsRegMem_.serialize(), connection->remoteRank(), connection->tag()); - std::vector serializedRemoteEpochIds; - communicator.bootstrapper()->recv(serializedRemoteEpochIds, connection->remoteRank(), connection->tag()); - remoteEpochIdsRegMem_ = RegisteredMemory::deserialize(serializedRemoteEpochIds); + communicator.sendMemoryOnSetup(localEpochIdsRegMem_, connection->remoteRank(), connection->tag()); + remoteEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection->remoteRank(), connection->tag()); } Epoch::~Epoch() { @@ -21,7 +19,7 @@ Epoch::~Epoch() { } void Epoch::signal() { - connection_->write(remoteEpochIdsRegMem_, offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_, offsetof(EpochIds, outbound_), sizeof(device_.epochIds_)); + connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_, offsetof(EpochIds, outbound_), sizeof(device_.epochIds_)); } -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index b9b28f896..32fb6e302 100644 --- a/src/include/communicator.hpp +++ 
b/src/include/communicator.hpp @@ -15,8 +15,9 @@ class ConnectionBase; struct Communicator::Impl { mscclppComm_t comm; - std::vector> connections; - std::unordered_map> ibContexts; + std::vector> connections_; + std::vector> toSetup_; + std::unordered_map> ibContexts_; std::shared_ptr bootstrap_; std::vector rankToHash_; diff --git a/src/include/connection.hpp b/src/include/connection.hpp index b28b58908..b380dbfd6 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -10,7 +10,7 @@ namespace mscclpp { // TODO: Add functionality to these classes for Communicator to do connectionSetup -class ConnectionBase : public Connection +class ConnectionBase : public Connection, public Setuppable { int remoteRank_; int tag_; @@ -19,9 +19,6 @@ class ConnectionBase : public Connection int remoteRank() override; int tag() override; - - virtual void startSetup(std::shared_ptr){}; - virtual void endSetup(std::shared_ptr){}; }; class CudaIpcConnection : public ConnectionBase @@ -52,8 +49,6 @@ class IBConnection : public ConnectionBase public: IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl); - ~IBConnection(); - Transport transport() override; Transport remoteTransport() override; @@ -63,7 +58,7 @@ class IBConnection : public ConnectionBase void flush() override; - void startSetup(std::shared_ptr bootstrap) override; + void beginSetup(std::shared_ptr bootstrap) override; void endSetup(std::shared_ptr bootstrap) override; }; diff --git a/src/include/epoch.hpp b/src/include/epoch.hpp index 2c6e3296d..742db85c2 100644 --- a/src/include/epoch.hpp +++ b/src/include/epoch.hpp @@ -35,7 +35,7 @@ class Epoch std::shared_ptr connection_; DeviceEpoch device_; RegisteredMemory localEpochIdsRegMem_; - RegisteredMemory remoteEpochIdsRegMem_; + NonblockingFuture remoteEpochIdsRegMem_; public: Epoch(Communicator& communicator, std::shared_ptr connection); @@ -43,9 +43,7 @@ class Epoch void signal(); - DeviceEpoch deviceEpoch() { - return 
device_; - } + DeviceEpoch deviceEpoch() { return device_; } }; } // namespace mscclpp diff --git a/src/include/host_connection.hpp b/src/include/host_connection.hpp deleted file mode 100644 index 8ac5d9f17..000000000 --- a/src/include/host_connection.hpp +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef MSCCLPP_HOST_CONNECTION_HPP_ -#define MSCCLPP_HOST_CONNECTION_HPP_ - -#include "comm.h" -#include "mscclpp.h" -#include "mscclpp.hpp" - -namespace mscclpp { - -struct HostConnection::Impl -{ - Communicator* comm; - mscclppConn* conn; - mscclppHostConn_t* hostConn; - - Impl(Communicator* comm, mscclppConn* conn); - - ~Impl(); -}; - -} // namespace mscclpp - -#endif // MSCCLPP_HOST_CONNECTION_HPP_ \ No newline at end of file diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index fde631801..b4111da8e 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -10,6 +10,7 @@ #include #include #include +#include namespace mscclpp { @@ -277,6 +278,33 @@ class Connection static std::shared_ptr getRegisteredMemoryImpl(RegisteredMemory&); }; +struct Setuppable +{ + virtual void beginSetup(std::shared_ptr) {} + virtual void endSetup(std::shared_ptr) {} +}; + +template +class NonblockingFuture +{ + std::future future; +public: + NonblockingFuture() = default; + NonblockingFuture(std::future&& future) : future(std::move(future)) {} + + bool ready() const + { + return future.wait_for(std::chrono::seconds(0)) == std::future_status::ready; + } + + T get() + { + if (!ready()) + throw std::runtime_error("NonblockingFuture::get() called before ready"); + return future.get(); + } +}; + class Communicator { public: @@ -301,6 +329,10 @@ class Communicator * Returns: a handle to the buffer */ RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports); + + void sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag); + + NonblockingFuture recvMemoryOnSetup(int remoteRank, int tag); /* Connect to a remote rank. 
This function only prepares metadata for connection. The actual connection * is made by a following call of mscclppConnectionSetup(). Note that this function is two-way and a connection @@ -318,10 +350,11 @@ class Communicator */ std::shared_ptr connect(int remoteRank, int tag, Transport transport); - /* Establish all connections declared by connect(). This function must be called after all connect() - * calls are made. This function ensures that all remote ranks are ready to communicate when it returns. - */ - void connectionSetup(); + /* Add a custom Setuppable object to a list of objects to be setup later, when setup() is called. */ + void addSetup(std::shared_ptr setuppable); + + /* Setup all objects that have registered for setup. This includes any connections created by connect(). */ + void setup(); struct Impl; diff --git a/tests/allgather_test_cpp.cu b/tests/allgather_test_cpp.cu index 8fb54733b..791e2ca90 100644 --- a/tests/allgather_test_cpp.cu +++ b/tests/allgather_test_cpp.cu @@ -224,7 +224,6 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co int cudaNum = rankToLocalRank(rank); std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum); mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr); - mscclpp::channel::DeviceChannelService channelService; for (int r = 0; r < world_size; ++r) { if (r == rank) @@ -238,14 +237,13 @@ void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& co transportType = ibTransport; } // Connect with all other ranks - auto conn = comm.connect(r, 0, transportType); - channelService.addChannel(conn); - // TODO: WIP - hostConn->registerBuffer(data_d, dataSize); - hostConns.push_back(hostConn); + auto connId = channelService.addChannel(comm.connect(r, 0, transportType)); + auto memoryId = channelService.addMemory(comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport)); } - comm.connectionSetup(); + comm.setup(); + + 
mscclpp::channel::DeviceChannelService channelService; std::vector devConns; std::transform( diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index 7c6423b46..c922eaae9 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -52,7 +52,7 @@ void make_connections(std::unique_ptr& communicator, int } } } - communicator->connectionSetup(); + communicator->setup(); } void write_remote(int rank, int worldSize, std::unordered_map>& connections, std::unordered_map& remoteRegisteredMemories, mscclpp::RegisteredMemory& registeredMemory, int writeSize){ From c7b7d20d850d6c3f531707130bb4da36ce5276fd Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Tue, 2 May 2023 20:35:16 +0000 Subject: [PATCH 38/54] Export epoch header --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 782129c09..e8c5bb256 100644 --- a/Makefile +++ b/Makefile @@ -135,7 +135,7 @@ HEADERS := $(wildcard src/include/*.h) CPPSOURCES := $(shell find ./ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)' -not -path "./build/*" -not -path "./python/*") PYTHONCPPSOURCES := $(shell find ./python/src/ -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)') -INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp +INCEXPORTS := mscclpp.h mscclppfifo.h mscclpp.hpp mscclppfifo.hpp epoch.hpp INCTARGETS := $(INCEXPORTS:%=$(BUILDDIR)/$(INCDIR)/%) LIBNAME := libmscclpp.so From 66ce01baf3ac14aba84799bc0ee135410015305e Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Tue, 2 May 2023 20:46:30 +0000 Subject: [PATCH 39/54] Make NonblockingFuture copyable --- src/include/mscclpp.hpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index b4111da8e..5186fbc2c 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -287,10 +287,11 @@ struct Setuppable template class NonblockingFuture { - std::future future; + 
std::shared_future future; public: NonblockingFuture() = default; - NonblockingFuture(std::future&& future) : future(std::move(future)) {} + NonblockingFuture(std::shared_future&& future) : future(std::move(future)) {} + NonblockingFuture(const NonblockingFuture&) = default; bool ready() const { From c44b48b361e0e36154c50ca10cfc6c42b715caad Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Tue, 2 May 2023 21:38:26 +0000 Subject: [PATCH 40/54] Epoch non-copyable --- src/include/epoch.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/include/epoch.hpp b/src/include/epoch.hpp index 742db85c2..ffd7464dc 100644 --- a/src/include/epoch.hpp +++ b/src/include/epoch.hpp @@ -39,6 +39,7 @@ class Epoch public: Epoch(Communicator& communicator, std::shared_ptr connection); + Epoch(const Epoch&) = delete; ~Epoch(); void signal(); From a4e6ffe2bc5f272e132705d0ad73001cd0921ef3 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Tue, 2 May 2023 21:39:43 +0000 Subject: [PATCH 41/54] epoch creation --- src/communicator.cc | 2 +- src/epoch.cc | 11 ++++++----- src/include/checks.hpp | 2 +- src/include/mscclpp.hpp | 7 ++++--- tests/communicator_test_cpp.cc | 12 +++++++++++- 5 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index 7af88c738..2507c175a 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -98,7 +98,7 @@ NonblockingFuture Communicator::recvMemoryOnSetup(int remoteRa { auto memoryReceiver = std::make_shared(remoteRank, tag); addSetup(memoryReceiver); - return memoryReceiver->memoryPromise_.get_future(); + return NonblockingFuture(memoryReceiver->memoryPromise_.get_future()); } MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, Transport transport) diff --git a/src/epoch.cc b/src/epoch.cc index a14191fd6..9263fd1ca 100644 --- a/src/epoch.cc +++ b/src/epoch.cc @@ -1,10 +1,11 @@ #include "epoch.hpp" #include "checks.hpp" #include "alloc.h" +#include "api.h" namespace mscclpp { 
-Epoch::Epoch(Communicator& communicator, std::shared_ptr connection) : connection_(connection) { +MSCCLPP_API_CPP Epoch::Epoch(Communicator& communicator, std::shared_ptr connection) : connection_(connection) { MSCCLPPTHROW(mscclppCudaCalloc(&device_.epochIds_, 1)); MSCCLPPTHROW(mscclppCudaCalloc(&device_.expectedInboundEpochId_, 1)); @@ -13,12 +14,12 @@ Epoch::Epoch(Communicator& communicator, std::shared_ptr connection) remoteEpochIdsRegMem_ = communicator.recvMemoryOnSetup(connection->remoteRank(), connection->tag()); } -Epoch::~Epoch() { - MSCCLPPTHROW(mscclppCudaFree(&device_.epochIds_)); - MSCCLPPTHROW(mscclppCudaFree(&device_.expectedInboundEpochId_)); +MSCCLPP_API_CPP Epoch::~Epoch() { + mscclppCudaFree(device_.epochIds_); + mscclppCudaFree(device_.expectedInboundEpochId_); } -void Epoch::signal() { +MSCCLPP_API_CPP void Epoch::signal() { connection_->write(remoteEpochIdsRegMem_.get(), offsetof(EpochIds, inboundReplica_), localEpochIdsRegMem_, offsetof(EpochIds, outbound_), sizeof(device_.epochIds_)); } diff --git a/src/include/checks.hpp b/src/include/checks.hpp index 69b222ee1..6473c92fa 100644 --- a/src/include/checks.hpp +++ b/src/include/checks.hpp @@ -17,7 +17,7 @@ if (res != mscclppSuccess && res != mscclppInProgress) { \ throw std::runtime_error(std::string("Call to " #call " failed with error code ") + mscclppGetErrorString(res)); \ } \ - } while (0); + } while (false) #define CUDATHROW(cmd) \ do { \ diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index 5186fbc2c..4c26131c4 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -35,15 +35,16 @@ class BaseBootstrap // TODO: move implementations of these helpers out of this header void send(const std::vector& data, int peer, int tag) { - send((void*)data.size(), sizeof(size_t), peer, tag); - send((void*)data.data(), data.size(), peer, tag); + size_t size = data.size(); + send((void*)&size, sizeof(size_t), peer, tag); + send((void*)data.data(), data.size(), peer, tag+1); 
} void recv(std::vector& data, int peer, int tag) { size_t size; recv((void*)&size, sizeof(size_t), peer, tag); data.resize(size); - recv((void*)data.data(), data.size(), peer, tag); + recv((void*)data.data(), data.size(), peer, tag+1); } }; diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc index c922eaae9..29712cd06 100644 --- a/tests/communicator_test_cpp.cc +++ b/tests/communicator_test_cpp.cc @@ -1,4 +1,5 @@ #include "mscclpp.hpp" +#include "epoch.hpp" #include #include @@ -88,7 +89,7 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) if (bootstrap->getRank() == 0) std::cout << "Connection setup passed" << std::endl; - int numBuffers = 1000; + int numBuffers = 1; std::vector devicePtr(numBuffers); int deviceBufferSize = 1024*1024; @@ -105,6 +106,15 @@ void test_communicator(int rank, int worldSize, int nranksPerNode) if (bootstrap->getRank() == 0) std::cout << "Memory registration for " << std::to_string(numBuffers) << " buffers passed" << std::endl; + std::vector> epochs; + for (auto entry : connections) { + auto& conn = entry.second; + epochs.emplace_back(std::make_unique(*communicator, conn)); + } + communicator->setup(); + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "Epochs are created" << std::endl; assert((deviceBufferSize / sizeof(int)) % worldSize == 0); size_t writeSize = deviceBufferSize / worldSize; From fc12947c5b01d397a7d78a27f2aa1b1f0be7c8c7 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Tue, 2 May 2023 21:42:25 +0000 Subject: [PATCH 42/54] fixing flush for IB --- src/connection.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/connection.cc b/src/connection.cc index f1ab06f8b..fd7283fcc 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -123,14 +123,13 @@ void IBConnection::flush() if (wcNum < 0) { throw std::runtime_error("pollCq failed: error no " + std::to_string(errno)); } - isWaiting = false; for (int i = 0; i < wcNum; ++i) { const 
struct ibv_wc* wc = reinterpret_cast(qp->getWc(i)); if (wc->status != IBV_WC_SUCCESS) { throw std::runtime_error("pollCq failed: status " + std::to_string(wc->status)); } - if (wc->opcode != IBV_WC_RDMA_WRITE) { - isWaiting = true; + if (wc->opcode == IBV_WC_RDMA_WRITE) { + isWaiting = false; } } } From 4ba851683274355697503d6d7a13ee2a9178f6fc Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Tue, 2 May 2023 23:14:13 +0000 Subject: [PATCH 43/54] allgather_test_cpp functional again --- Makefile | 6 +-- src/communicator.cc | 4 +- src/fifo.cc | 21 ++++---- src/ib.cc | 7 +-- src/include/channel.hpp | 8 +-- src/proxy_cpp.cc | 1 + tests/allgather_test_cpp.cu | 100 +++++++++++++++++++----------------- 7 files changed, 79 insertions(+), 68 deletions(-) diff --git a/Makefile b/Makefile index e8c5bb256..7b44e154f 100644 --- a/Makefile +++ b/Makefile @@ -120,8 +120,8 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc) LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc) -LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc epoch.cc) -#LIBSRCS += $(addprefix src/,fifo.cc host_connection.cc proxy_cpp.cc basic_proxy_handler.cc) +LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc) +LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc) ifneq ($(NPKIT), 0) LIBSRCS += $(addprefix src/misc/,npkit.cc) endif @@ -149,7 +149,7 @@ UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS)) TESTSDIR := tests -TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu bootstrap_test_cpp.cc communicator_test_cpp.cc) # allgather_test_cpp.cu +TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu bootstrap_test_cpp.cc communicator_test_cpp.cc allgather_test_cpp.cu) TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst 
%.cu,%.o,$(TESTSSRCS)) TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS)) diff --git a/src/communicator.cc b/src/communicator.cc index 2507c175a..074d127fd 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -72,7 +72,7 @@ struct MemorySender : public Setuppable int tag_; }; -void Communicator::sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag) +MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag) { addSetup(std::make_shared(memory, remoteRank, tag)); } @@ -94,7 +94,7 @@ struct MemoryReceiver : public Setuppable int tag_; }; -NonblockingFuture Communicator::recvMemoryOnSetup(int remoteRank, int tag) +MSCCLPP_API_CPP NonblockingFuture Communicator::recvMemoryOnSetup(int remoteRank, int tag) { auto memoryReceiver = std::make_shared(remoteRank, tag); addSetup(memoryReceiver); diff --git a/src/fifo.cc b/src/fifo.cc index c2fdd7385..d5d704229 100644 --- a/src/fifo.cc +++ b/src/fifo.cc @@ -1,6 +1,7 @@ #include "alloc.h" #include "checks.hpp" #include "mscclppfifo.hpp" +#include "api.h" #include #include #include @@ -24,7 +25,7 @@ struct HostProxyFifo::Impl cudaStream_t stream; }; -HostProxyFifo::HostProxyFifo() +MSCCLPP_API_CPP HostProxyFifo::HostProxyFifo() { pimpl = std::make_unique(); MSCCLPPTHROW(mscclppCudaCalloc(&pimpl->deviceFifo.head, 1)); @@ -34,27 +35,27 @@ HostProxyFifo::HostProxyFifo() pimpl->hostTail = 0; } -HostProxyFifo::~HostProxyFifo() +MSCCLPP_API_CPP HostProxyFifo::~HostProxyFifo() { - MSCCLPPTHROW(mscclppCudaFree(pimpl->deviceFifo.head)); - MSCCLPPTHROW(mscclppCudaHostFree(pimpl->deviceFifo.triggers)); - MSCCLPPTHROW(mscclppCudaFree(pimpl->deviceFifo.tailReplica)); - CUDATHROW(cudaStreamDestroy(pimpl->stream)); + mscclppCudaFree(pimpl->deviceFifo.head); + mscclppCudaHostFree(pimpl->deviceFifo.triggers); + mscclppCudaFree(pimpl->deviceFifo.tailReplica); + cudaStreamDestroy(pimpl->stream); } -void 
HostProxyFifo::poll(ProxyTrigger* trigger) +MSCCLPP_API_CPP void HostProxyFifo::poll(ProxyTrigger* trigger) { __m128i xmm0 = _mm_load_si128((__m128i*)&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]); _mm_store_si128((__m128i*)trigger, xmm0); } -void HostProxyFifo::pop() +MSCCLPP_API_CPP void HostProxyFifo::pop() { *(volatile uint64_t*)(&pimpl->deviceFifo.triggers[pimpl->hostTail % MSCCLPP_PROXY_FIFO_SIZE]) = 0; (pimpl->hostTail)++; } -void HostProxyFifo::flushTail(bool sync) +MSCCLPP_API_CPP void HostProxyFifo::flushTail(bool sync) { // Flush the tail to device memory. This is either triggered every MSCCLPP_PROXY_FIFO_FLUSH_COUNTER to make sure // that the fifo can make progress even if there is no request mscclppSync. However, mscclppSync type is for flush @@ -66,7 +67,7 @@ void HostProxyFifo::flushTail(bool sync) } } -DeviceProxyFifo HostProxyFifo::toDevice() +MSCCLPP_API_CPP DeviceProxyFifo HostProxyFifo::deviceFifo() { return pimpl->deviceFifo; } diff --git a/src/ib.cc b/src/ib.cc index ec7e95f25..7e77b235b 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -11,6 +11,7 @@ #include "debug.h" #include "ib.hpp" #include "mscclpp.hpp" +#include "api.h" #include #include @@ -372,14 +373,14 @@ const std::string& IbCtx::getDevName() const return this->devName; } -int getIBDeviceCount() +MSCCLPP_API_CPP int getIBDeviceCount() { int num; ibv_get_device_list(&num); return num; } -std::string getIBDeviceName(Transport ibTransport) +MSCCLPP_API_CPP std::string getIBDeviceName(Transport ibTransport) { int num; struct ibv_device** devices = ibv_get_device_list(&num); @@ -418,7 +419,7 @@ std::string getIBDeviceName(Transport ibTransport) return devices[ibTransportIndex]->name; } -Transport getIBTransportByDeviceName(const std::string& ibDeviceName) +MSCCLPP_API_CPP Transport getIBTransportByDeviceName(const std::string& ibDeviceName) { int num; struct ibv_device** devices = ibv_get_device_list(&num); diff --git a/src/include/channel.hpp b/src/include/channel.hpp 
index ace576614..42826f4f8 100644 --- a/src/include/channel.hpp +++ b/src/include/channel.hpp @@ -13,7 +13,8 @@ namespace channel { class Channel { public: - Channel(std::shared_ptr connection) : connection_(connection), epoch_(std::make_shared()) {}; + Channel(Communicator& communicator, std::shared_ptr connection) + : connection_(connection), epoch_(std::make_shared(communicator, connection)) {}; Connection& connection() { return *connection_; } Epoch& epoch() { return *epoch_; } @@ -176,10 +177,10 @@ inline ProxyHandler makeChannelProxyHandler(DeviceChannelService& channelService class DeviceChannelService { public: - DeviceChannelService() : proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }) {} + DeviceChannelService(Communicator& communicator) : communicator_(communicator), proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }) {} ChannelId addChannel(std::shared_ptr connection) { - channels_.push_back(Channel(connection)); + channels_.push_back(Channel(communicator_, connection)); return channels_.size() - 1; } @@ -195,6 +196,7 @@ class DeviceChannelService { void stopProxy() { proxy_.stop(); } private: + Communicator& communicator_; std::vector channels_; std::vector memories_; Proxy proxy_; diff --git a/src/proxy_cpp.cc b/src/proxy_cpp.cc index b55d6995f..2fb8c2b0d 100644 --- a/src/proxy_cpp.cc +++ b/src/proxy_cpp.cc @@ -1,3 +1,4 @@ +#include "proxy.hpp" #include "api.h" #include "mscclpp.hpp" #include "utils.h" diff --git a/tests/allgather_test_cpp.cu b/tests/allgather_test_cpp.cu index 791e2ca90..34050814e 100644 --- a/tests/allgather_test_cpp.cu +++ b/tests/allgather_test_cpp.cu @@ -49,9 +49,9 @@ static double getTime(void) return (tspec.tv_nsec / 1.0e9) + tspec.tv_sec; } -__constant__ mscclpp::channel::SimpleDeviceConnection constDevConns[16]; +__constant__ mscclpp::channel::SimpleDeviceChannel constDevChans[16]; -__device__ void allgather0(mscclpp::channel::SimpleDeviceConnection devConn, int rank, int 
world_size, int remoteRank, +__device__ void allgather0(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int remoteRank, size_t nelemsPerGPU) { // this allgather is really simple and implemented as an alltoall @@ -59,19 +59,19 @@ __device__ void allgather0(mscclpp::channel::SimpleDeviceConnection devConn, int // this thread's role is a sender role // put your data asynchronously if ((threadIdx.x % 32) == 0) - devConn.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int)); + devChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int)); // make sure everyone is put their data before some thread randomly blocks everyone else in signal __syncthreads(); // push with flag and sync to make sure the data is received if ((threadIdx.x % 32) == 0) - devConn.flush(); + devChan.flush(); // this thread's role is a receiver role. wait on the semaphore to make sure the data is ready if ((threadIdx.x % 32) == 0) - devConn.wait(); + devChan.wait(); } -__device__ void localAllGather(mscclpp::channel::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, +__device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode, int remoteRank, uint64_t offset, uint64_t size) { // this allgather algorithm works as follows: @@ -84,25 +84,25 @@ __device__ void localAllGather(mscclpp::channel::SimpleDeviceConnection devConn, if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) { // put your data to GPU (rank+i) % nranksPerNode and signal in one call if ((threadIdx.x % 32) == 0) - devConn.putWithSignalAndFlush(offset, size); + devChan.putWithSignalAndFlush(offset, size); } // wait for the data from GPU (rank-i) % nranksPerNode to arrive if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) { if ((threadIdx.x % 32) == 0) - devConn.wait(); + devChan.wait(); } asm volatile("bar.sync %0, %1;" ::"r"(11), 
"r"((nranksPerNode - 1) * 32) : "memory"); } } -__device__ void allgather1(mscclpp::channel::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, +__device__ void allgather1(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode, int remoteRank, size_t nelemsPerGPU) { - localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int), + localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int)); } -__device__ void allgather2(mscclpp::channel::SimpleDeviceConnection devConn, int rank, int world_size, int nranksPerNode, +__device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode, int remoteRank, size_t nelemsPerGPU) { // this allgather is a pipelined and hierarchical one and only works for two nodes @@ -120,17 +120,17 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceConnection devConn, int // Step 1 // local allgather if (remoteRank / nranksPerNode == rank / nranksPerNode) { - localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int), + localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int)); } // cross-node exchange if (remoteRank % nranksPerNode == rank % nranksPerNode) { // opposite side if ((threadIdx.x % 32) == 0) - devConn.putWithSignalAndFlush(rank * nelemsPerGPU * sizeof(int), + devChan.putWithSignalAndFlush(rank * nelemsPerGPU * sizeof(int), (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int)); if ((threadIdx.x % 32) == 0) - devConn.wait(); + devChan.wait(); } __syncthreads(); @@ -139,7 +139,7 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceConnection devConn, int // local allgather int otherNghr = (rank + nranksPerNode) % world_size; if (remoteRank / nranksPerNode == rank / 
nranksPerNode) { - localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, otherNghr * nelemsPerGPU * sizeof(int), + localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, otherNghr * nelemsPerGPU * sizeof(int), (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int)); } @@ -147,11 +147,11 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceConnection devConn, int if (remoteRank % nranksPerNode == rank % nranksPerNode) { // opposite side if ((threadIdx.x % 32) == 0) - devConn.putWithSignalAndFlush((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * + devChan.putWithSignalAndFlush((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int), nelemsPerGPU / pipelineSize * sizeof(int)); if ((threadIdx.x % 32) == 0) - devConn.wait(); + devChan.wait(); } __syncthreads(); @@ -159,7 +159,7 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceConnection devConn, int // Step 3 // local allgather if (remoteRank / nranksPerNode == rank / nranksPerNode) { - localAllGather(devConn, rank, world_size, nranksPerNode, remoteRank, + localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, (otherNghr * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int), nelemsPerGPU / pipelineSize * sizeof(int)); } @@ -167,18 +167,18 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceConnection devConn, int __global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelemsPerGPU, int kernel) { - // find the mapping between remoteRank and devConns + // find the mapping between remoteRank and devChans int warpId = threadIdx.x / 32; int remoteRank = (warpId < rank) ? 
warpId : warpId + 1; // Each warp is responsible for one of the remote ranks - mscclpp::channel::SimpleDeviceConnection devConn = constDevConns[warpId]; + mscclpp::channel::SimpleDeviceChannel devChan = constDevChans[warpId]; if (kernel == 0) - allgather0(devConn, rank, world_size, remoteRank, nelemsPerGPU); + allgather0(devChan, rank, world_size, remoteRank, nelemsPerGPU); else if (kernel == 1) - allgather1(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU); + allgather1(devChan, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU); else if (kernel == 2) - allgather2(devConn, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU); + allgather2(devChan, rank, world_size, nranksPerNode, remoteRank, nelemsPerGPU); } int rankToLocalRank(int rank) @@ -218,41 +218,44 @@ void initializeAndAllocateAllGatherData(int rank, int world_size, size_t dataSiz CUDACHECK(cudaMemcpy(*data_d, *data_h, dataSize, cudaMemcpyHostToDevice)); } -void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, int* data_d, size_t dataSize) +void setupMscclppConnections(int rank, int world_size, mscclpp::Communicator& comm, mscclpp::channel::DeviceChannelService& channelService, int* data_d, size_t dataSize) { int thisNode = rankToNode(rank); int cudaNum = rankToLocalRank(rank); std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum); mscclpp::Transport ibTransport = mscclpp::getIBTransportByDeviceName(ibDevStr); + std::vector channelIds; + std::vector localMemories; + std::vector> remoteMemories; for (int r = 0; r < world_size; ++r) { if (r == rank) continue; mscclpp::Transport transport; - const char* ibDev = ibDevStr.c_str(); if (rankToNode(r) == thisNode) { - ibDev = NULL; - transportType = mscclpp::Transport::CudaIpc; + transport = mscclpp::Transport::CudaIpc; } else { - transportType = ibTransport; + transport = ibTransport; } // Connect with all other ranks - auto connId = channelService.addChannel(comm.connect(r, 0, transportType)); - 
auto memoryId = channelService.addMemory(comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport)); + channelIds.push_back(channelService.addChannel(comm.connect(r, 0, transport))); + auto memory = comm.registerMemory(data_d, dataSize, mscclpp::Transport::CudaIpc | ibTransport); + localMemories.push_back(memory); + comm.sendMemoryOnSetup(memory, r, 0); + remoteMemories.push_back(comm.recvMemoryOnSetup(r, 0)); } comm.setup(); - mscclpp::channel::DeviceChannelService channelService; - - std::vector devConns; - std::transform( - hostConns.begin(), hostConns.end(), std::back_inserter(devConns), - [](std::shared_ptr& hostConn) { return mscclpp::SimpleDeviceConnection(*hostConn); }); + std::vector devChannels; + for (size_t i = 0; i < channelIds.size(); ++i) { + devChannels.push_back(mscclpp::channel::SimpleDeviceChannel(channelService.deviceChannel(channelIds[i]), + channelService.addMemory(remoteMemories[i].get()), channelService.addMemory(localMemories[i]))); + } - assert(devConns.size() < sizeof(constDevConns) / sizeof(mscclpp::SimpleDeviceConnection)); + assert(devChannels.size() < sizeof(constDevChans) / sizeof(mscclpp::channel::SimpleDeviceChannel)); CUDACHECK( - cudaMemcpyToSymbol(constDevConns, devConns.data(), sizeof(mscclpp::SimpleDeviceConnection) * devConns.size())); + cudaMemcpyToSymbol(constDevChans, devChannels.data(), sizeof(mscclpp::channel::SimpleDeviceChannel) * devChannels.size())); } void printUsage(const char* prog, bool isMpi) @@ -405,7 +408,10 @@ int main(int argc, const char* argv[]) try { if (rank == 0) printf("Initializing MSCCL++\n"); - mscclpp::Communicator comm(world_size, ip_port, rank); + auto bootstrapper = std::make_shared(rank, world_size); + bootstrapper->initialize(ip_port); + mscclpp::Communicator comm(bootstrapper); + mscclpp::channel::DeviceChannelService channelService(comm); if (rank == 0) printf("Initializing data for allgather test\n"); @@ -413,11 +419,11 @@ int main(int argc, const char* argv[]) if 
(rank == 0) printf("Setting up the connection in MSCCL++\n"); - setupMscclppConnections(rank, world_size, comm, data_d, dataSize); + setupMscclppConnections(rank, world_size, comm, channelService, data_d, dataSize); if (rank == 0) printf("Launching MSCCL++ proxy threads\n"); - comm.startProxying(); + channelService.startProxy(); if (rank == 0) printf("Testing the correctness of AllGather implementation\n"); @@ -437,7 +443,7 @@ int main(int argc, const char* argv[]) } int tmp[16]; // A simple barrier - comm.bootstrapAllGather(tmp, sizeof(int)); + bootstrapper->allGather(tmp, sizeof(int)); if (rank == 0) printf("Successfully checked the correctness\n"); @@ -446,12 +452,12 @@ int main(int argc, const char* argv[]) if (rank == 0) printf("Running %d iterations of the kernel without CUDA graph\n", iterwithoutcudagraph); CUDACHECK(cudaStreamSynchronize(stream)); - comm.bootstrapAllGather(tmp, sizeof(int)); + bootstrapper->allGather(tmp, sizeof(int)); for (int i = 0; i < iterwithoutcudagraph; ++i) { kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum); } CUDACHECK(cudaStreamSynchronize(stream)); - comm.bootstrapAllGather(tmp, sizeof(int)); + bootstrapper->allGather(tmp, sizeof(int)); // cudaGraph Capture int cudagraphiter = 10; @@ -480,7 +486,7 @@ int main(int argc, const char* argv[]) if (rank == 0) printf("Running %d iterations of the CUDA graph with %d iterations of the kernel\n", cudagraphlaunch, cudagraphiter); - comm.bootstrapAllGather(tmp, sizeof(int)); + bootstrapper->allGather(tmp, sizeof(int)); double t0, t1, ms, time_in_us; t0 = getTime(); for (int i = 0; i < cudagraphlaunch; ++i) { @@ -493,11 +499,11 @@ int main(int argc, const char* argv[]) time_in_us = ms * 1000. 
/ (float)cudagraphlaunch / (float)cudagraphiter; printf("Rank %d report: size %lu time: %f us/iter algBW %f GBps\n", rank, dataSize, time_in_us, (double)(dataSize) / 1e9 / (time_in_us / 1e6)); - comm.bootstrapAllGather(tmp, sizeof(int)); + bootstrapper->allGather(tmp, sizeof(int)); if (rank == 0) printf("Stopping MSCCL++ proxy threads\n"); - comm.stopProxying(); + channelService.stopProxy(); } catch (std::exception& e) { // todo: throw exceptions in the implementation and process them here From 54d1e1872caf5918a8df2f4793e71f7595c485ab Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Tue, 2 May 2023 23:53:31 +0000 Subject: [PATCH 44/54] testing writes with signal is passing --- Makefile | 2 +- src/communicator.cc | 6 +- src/connection.cc | 8 +- src/include/connection.hpp | 1 + src/include/mscclpp.hpp | 2 +- tests/communicator_test_cpp.cc | 193 ---------------------- tests/communicator_test_cpp.cu | 289 +++++++++++++++++++++++++++++++++ 7 files changed, 299 insertions(+), 202 deletions(-) delete mode 100644 tests/communicator_test_cpp.cc create mode 100644 tests/communicator_test_cpp.cu diff --git a/Makefile b/Makefile index e8c5bb256..cb71ec866 100644 --- a/Makefile +++ b/Makefile @@ -149,7 +149,7 @@ UTOBJTARGETS := $(UTOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) UTBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(UTOBJS)) TESTSDIR := tests -TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu bootstrap_test_cpp.cc communicator_test_cpp.cc) # allgather_test_cpp.cu +TESTSSRCS := $(addprefix $(TESTSDIR)/,bootstrap_test.cc allgather_test_standalone.cu communicator_test_cpp.cu bootstrap_test_cpp.cc) # allgather_test_cpp.cu TESTSOBJS := $(patsubst %.cc,%.o,$(TESTSSRCS)) $(patsubst %.cu,%.o,$(TESTSSRCS)) TESTSOBJTARGETS := $(TESTSOBJS:%=$(BUILDDIR)/$(OBJDIR)/%) TESTSBINS := $(patsubst %.o,$(BUILDDIR)/$(BINDIR)/%,$(TESTSOBJS)) diff --git a/src/communicator.cc b/src/communicator.cc index 2507c175a..1fd641320 100644 --- a/src/communicator.cc +++ 
b/src/communicator.cc @@ -72,7 +72,7 @@ struct MemorySender : public Setuppable int tag_; }; -void Communicator::sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag) +MSCCLPP_API_CPP void Communicator::sendMemoryOnSetup(RegisteredMemory memory, int remoteRank, int tag) { addSetup(std::make_shared(memory, remoteRank, tag)); } @@ -94,14 +94,14 @@ struct MemoryReceiver : public Setuppable int tag_; }; -NonblockingFuture Communicator::recvMemoryOnSetup(int remoteRank, int tag) +MSCCLPP_API_CPP NonblockingFuture Communicator::recvMemoryOnSetup(int remoteRank, int tag) { auto memoryReceiver = std::make_shared(remoteRank, tag); addSetup(memoryReceiver); return NonblockingFuture(memoryReceiver->memoryPromise_.get_future()); } -MSCCLPP_API_CPP std::shared_ptr Communicator::connect(int remoteRank, int tag, Transport transport) +MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup(int remoteRank, int tag, Transport transport) { std::shared_ptr conn; if (transport == Transport::CudaIpc) { diff --git a/src/connection.cc b/src/connection.cc index fd7283fcc..66c54f062 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -75,7 +75,7 @@ void CudaIpcConnection::flush() // IBConnection IBConnection::IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl) - : ConnectionBase(remoteRank, tag), transport_(transport), remoteTransport_(Transport::Unknown) + : ConnectionBase(remoteRank, tag), transport_(transport), remoteTransport_(Transport::Unknown), numSignaledSends(0) { qp = commImpl.getIbContext(transport)->createQp(); } @@ -110,6 +110,7 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem qp->stageSend(srcMr, dstMrInfo, (uint32_t)size, /*wrId=*/0, /*srcOffset=*/srcOffset, /*dstOffset=*/dstOffset, /*signaled=*/true); + numSignaledSends++; qp->postSend(); INFO(MSCCLPP_NET, "IBConnection write: from %p to %p, size %lu", (uint8_t*)srcMr->getBuff() + srcOffset, (uint8_t*)dstMrInfo.addr + dstOffset, 
size); // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size); @@ -117,8 +118,7 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem void IBConnection::flush() { - bool isWaiting = true; - while (isWaiting) { + while (numSignaledSends) { int wcNum = qp->pollCq(); if (wcNum < 0) { throw std::runtime_error("pollCq failed: error no " + std::to_string(errno)); @@ -129,7 +129,7 @@ void IBConnection::flush() throw std::runtime_error("pollCq failed: status " + std::to_string(wc->status)); } if (wc->opcode == IBV_WC_RDMA_WRITE) { - isWaiting = false; + numSignaledSends--; } } } diff --git a/src/include/connection.hpp b/src/include/connection.hpp index b380dbfd6..8d1dec876 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -45,6 +45,7 @@ class IBConnection : public ConnectionBase Transport transport_; Transport remoteTransport_; IbQp* qp; + int numSignaledSends; public: IBConnection(int remoteRank, int tag, Transport transport, Communicator::Impl& commImpl); diff --git a/src/include/mscclpp.hpp b/src/include/mscclpp.hpp index 4c26131c4..47ca94376 100644 --- a/src/include/mscclpp.hpp +++ b/src/include/mscclpp.hpp @@ -350,7 +350,7 @@ class Communicator * transportType: the type of transport to be used (mscclppTransportP2P or mscclppTransportIB) * ibDev: the name of the IB device to be used. Expects a null for mscclppTransportP2P. */ - std::shared_ptr connect(int remoteRank, int tag, Transport transport); + std::shared_ptr connectOnSetup(int remoteRank, int tag, Transport transport); /* Add a custom Setuppable object to a list of objects to be setup later, when setup() is called. 
*/ void addSetup(std::shared_ptr setuppable); diff --git a/tests/communicator_test_cpp.cc b/tests/communicator_test_cpp.cc deleted file mode 100644 index 29712cd06..000000000 --- a/tests/communicator_test_cpp.cc +++ /dev/null @@ -1,193 +0,0 @@ -#include "mscclpp.hpp" -#include "epoch.hpp" - -#include -#include -#include -#include -#include -#include - -#define CUDATHROW(cmd) \ - do { \ - cudaError_t err = cmd; \ - if (err != cudaSuccess) { \ - throw std::runtime_error(std::string("Cuda failure '") + cudaGetErrorString(err) + "'"); \ - } \ - } while (false) - -mscclpp::Transport findIb(int localRank) -{ - mscclpp::Transport IBs[] = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2, - mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5, - mscclpp::Transport::IB6, mscclpp::Transport::IB7}; - return IBs[localRank]; -} - -void register_all_memories(std::unique_ptr& communicator, int rank, int worldSize, void* devicePtr, size_t deviceBufferSize, mscclpp::Transport myIbDevice, mscclpp::RegisteredMemory& localMemory, std::unordered_map& remoteMemory){ - localMemory = communicator->registerMemory(devicePtr, deviceBufferSize, mscclpp::Transport::CudaIpc | myIbDevice); - auto serialized = localMemory.serialize(); - int serializedSize = serialized.size(); - for (int i = 0; i < worldSize; i++) { - if (i != rank){ - communicator->bootstrapper()->send(serialized.data(), serializedSize, i, 0); - } - } - for (int i = 0; i < worldSize; i++) { - if (i != rank){ - std::vector deserialized(serializedSize); - communicator->bootstrapper()->recv(deserialized.data(), serializedSize, i, 0); - auto remote = mscclpp::RegisteredMemory::deserialize(deserialized); - remoteMemory[i] = remote; - } - } -} - -void make_connections(std::unique_ptr& communicator, int rank, int worldSize, int nRanksPerNode, mscclpp::Transport myIbDevice, std::unordered_map>& connections){ - for (int i = 0; i < worldSize; i++) { - if (i != rank){ - if (i / nRanksPerNode == 
rank / nRanksPerNode) { - connections[i] = communicator->connect(i, 0, mscclpp::Transport::CudaIpc); - } else { - connections[i] = communicator->connect(i, 0, myIbDevice); - } - } - } - communicator->setup(); -} - -void write_remote(int rank, int worldSize, std::unordered_map>& connections, std::unordered_map& remoteRegisteredMemories, mscclpp::RegisteredMemory& registeredMemory, int writeSize){ - for (int i = 0; i < worldSize; i++) { - if (i != rank) { - auto& conn = connections.at(i); - auto& peerMemory = remoteRegisteredMemories.at(i); - // printf("write to rank: %d, rank is %d\n", peerMemory.rank(), rank); - conn->write(peerMemory, rank * writeSize, registeredMemory, rank * writeSize, writeSize); - conn->flush(); - } - } - -} - -void test_communicator(int rank, int worldSize, int nranksPerNode) -{ - auto bootstrap = std::make_shared(rank, worldSize); - mscclpp::UniqueId id; - if (bootstrap->getRank() == 0) - id = bootstrap->createUniqueId(); - MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); - bootstrap->initialize(id); - - auto communicator = std::make_unique(bootstrap); - if (bootstrap->getRank() == 0) - std::cout << "Communicator initialization passed" << std::endl; - - std::unordered_map> connections; - auto myIbDevice = findIb(rank % nranksPerNode); - - make_connections(communicator, rank, worldSize, nranksPerNode, myIbDevice, connections); - if (bootstrap->getRank() == 0) - std::cout << "Connection setup passed" << std::endl; - - int numBuffers = 1; - std::vector devicePtr(numBuffers); - int deviceBufferSize = 1024*1024; - - std::vector localMemory(numBuffers); - std::vector> remoteMemory(numBuffers); - - for (int n = 0; n < numBuffers; n++) { - if (n % 100 == 0) - std::cout << "Registering memory for " << std::to_string(n) << " buffers" << std::endl; - CUDATHROW(cudaMalloc(&devicePtr[n], deviceBufferSize)); - register_all_memories(communicator, rank, worldSize, devicePtr[n], deviceBufferSize, myIbDevice, localMemory[n], remoteMemory[n]); - } - 
bootstrap->barrier(); - if (bootstrap->getRank() == 0) - std::cout << "Memory registration for " << std::to_string(numBuffers) << " buffers passed" << std::endl; - - std::vector> epochs; - for (auto entry : connections) { - auto& conn = entry.second; - epochs.emplace_back(std::make_unique(*communicator, conn)); - } - communicator->setup(); - bootstrap->barrier(); - if (bootstrap->getRank() == 0) - std::cout << "Epochs are created" << std::endl; - - assert((deviceBufferSize / sizeof(int)) % worldSize == 0); - size_t writeSize = deviceBufferSize / worldSize; - size_t dataCount = deviceBufferSize / sizeof(int); - for (int n = 0; n < numBuffers; n++){ - std::vector hostBuffer(dataCount, 0); - for (int i = 0; i < dataCount; i++) { - hostBuffer[i] = rank + n * worldSize; - } - CUDATHROW(cudaMemcpy(devicePtr[n], hostBuffer.data(), deviceBufferSize, cudaMemcpyHostToDevice)); - } - CUDATHROW(cudaDeviceSynchronize()); - - bootstrap->barrier(); - if (bootstrap->getRank() == 0) - std::cout << "CUDA memory initialization passed" << std::endl; - - for (int n = 0; n < numBuffers; n++){ - write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], writeSize); - } - bootstrap->barrier(); - if (bootstrap->getRank() == 0) - std::cout << "RDMA write for " << std::to_string(numBuffers) << " buffers passed" << std::endl; - - for (int n = 0; n < numBuffers; n++){ - // polling until it becomes ready - bool ready = false; - int niter = 0; - std::vector hostBuffer(dataCount, 0); - do { - ready = true; - CUDATHROW(cudaMemcpy(hostBuffer.data(), devicePtr[n], deviceBufferSize, cudaMemcpyDeviceToHost)); - for (int i = 0; i < worldSize; i++) { - for (int j = i*writeSize/sizeof(int); j < (i+1)*writeSize/sizeof(int); j++) { - if (hostBuffer[j] != i + n * worldSize) { - ready = false; - } - } - } - if (niter == 10000){ - throw std::runtime_error("Polling is stuck."); - } - niter++; - } while (!ready); - } - - bootstrap->barrier(); - if (bootstrap->getRank() == 0) - std::cout << 
"Polling for " << std::to_string(numBuffers) << " buffers passed" << std::endl; - - if (bootstrap->getRank() == 0) - std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl; - - for (int n = 0; n < numBuffers; n++){ - CUDATHROW(cudaFree(devicePtr[n])); - } -} - -int main(int argc, char** argv) -{ - int rank, worldSize; - MPI_Init(&argc, &argv); - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - MPI_Comm_size(MPI_COMM_WORLD, &worldSize); - MPI_Comm shmcomm; - MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm); - int shmWorldSize; - MPI_Comm_size(shmcomm, &shmWorldSize); - int nranksPerNode = shmWorldSize; - MPI_Comm_free(&shmcomm); - - test_communicator(rank, worldSize, nranksPerNode); - - MPI_Finalize(); - return 0; -} \ No newline at end of file diff --git a/tests/communicator_test_cpp.cu b/tests/communicator_test_cpp.cu new file mode 100644 index 000000000..fcdd0f5a3 --- /dev/null +++ b/tests/communicator_test_cpp.cu @@ -0,0 +1,289 @@ +#include "mscclpp.hpp" +#include "epoch.hpp" + +#include +#include +#include +#include +#include +#include + +#define CUDATHROW(cmd) \ + do { \ + cudaError_t err = cmd; \ + if (err != cudaSuccess) { \ + throw std::runtime_error(std::string("Cuda failure '") + cudaGetErrorString(err) + "'"); \ + } \ + } while (false) + +mscclpp::Transport findIb(int localRank) +{ + mscclpp::Transport IBs[] = {mscclpp::Transport::IB0, mscclpp::Transport::IB1, mscclpp::Transport::IB2, + mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5, + mscclpp::Transport::IB6, mscclpp::Transport::IB7}; + return IBs[localRank]; +} + +void register_all_memories(mscclpp::Communicator& communicator, int rank, int worldSize, void* devicePtr, size_t deviceBufferSize, mscclpp::Transport myIbDevice, mscclpp::RegisteredMemory& localMemory, std::unordered_map& remoteMemory){ + localMemory = communicator.registerMemory(devicePtr, deviceBufferSize, mscclpp::Transport::CudaIpc | myIbDevice); + std::unordered_map> 
futureRemoteMemory; + for (int i = 0; i < worldSize; i++) { + if (i != rank){ + communicator.sendMemoryOnSetup(localMemory, i, 0); + futureRemoteMemory[i] = communicator.recvMemoryOnSetup(i, 0); + } + } + communicator.setup(); + for (int i = 0; i < worldSize; i++) { + if (i != rank){ + remoteMemory[i] = futureRemoteMemory[i].get(); + } + } + + + // auto serialized = localMemory.serialize(); + // int serializedSize = serialized.size(); + // for (int i = 0; i < worldSize; i++) { + // if (i != rank){ + // communicator.bootstrapper()->send(serialized.data(), serializedSize, i, 0); + // } + // } + // for (int i = 0; i < worldSize; i++) { + // if (i != rank){ + // std::vector deserialized(serializedSize); + // communicator.bootstrapper()->recv(deserialized.data(), serializedSize, i, 0); + // auto remote = mscclpp::RegisteredMemory::deserialize(deserialized); + // remoteMemory[i] = remote; + // } + // } +} + +void make_connections(mscclpp::Communicator& communicator, int rank, int worldSize, int nRanksPerNode, mscclpp::Transport myIbDevice, std::unordered_map>& connections){ + for (int i = 0; i < worldSize; i++) { + if (i != rank){ + if (i / nRanksPerNode == rank / nRanksPerNode) { + connections[i] = communicator.connectOnSetup(i, 0, mscclpp::Transport::CudaIpc); + } else { + connections[i] = communicator.connectOnSetup(i, 0, myIbDevice); + } + } + } + communicator.setup(); +} + +void write_remote(int rank, int worldSize, std::unordered_map>& connections, + std::unordered_map& remoteRegisteredMemories, mscclpp::RegisteredMemory& registeredMemory, int dataCountPerRank){ + for (int i = 0; i < worldSize; i++) { + if (i != rank) { + auto& conn = connections.at(i); + auto& peerMemory = remoteRegisteredMemories.at(i); + conn->write(peerMemory, rank * dataCountPerRank * sizeof(int), registeredMemory, rank * dataCountPerRank*sizeof(int), dataCountPerRank*sizeof(int)); + conn->flush(); + } + } +} + +void device_buffer_init(int rank, int worldSize, int dataCount, std::vector& 
devicePtr){ + for (int n = 0; n < (int)devicePtr.size(); n++){ + std::vector hostBuffer(dataCount, 0); + for (int i = 0; i < dataCount; i++) { + hostBuffer[i] = rank + n * worldSize; + } + CUDATHROW(cudaMemcpy(devicePtr[n], hostBuffer.data(), dataCount*sizeof(int), cudaMemcpyHostToDevice)); + } + CUDATHROW(cudaDeviceSynchronize()); +} + +bool test_device_buffer_write_correctness(int worldSize, int dataCount, std::vector& devicePtr){ + for (int n = 0; n < (int)devicePtr.size(); n++){ + std::vector hostBuffer(dataCount, 0); + CUDATHROW(cudaMemcpy(hostBuffer.data(), devicePtr[n], dataCount*sizeof(int), cudaMemcpyDeviceToHost)); + for (int i = 0; i < worldSize; i++) { + for (int j = i*dataCount/worldSize; j < (i+1)*dataCount/worldSize; j++) { + if (hostBuffer[j] != i + n * worldSize) { + return false; + } + } + } + } + return true; +} + +void test_write(int rank, int worldSize, int deviceBufferSize, std::shared_ptr bootstrap, std::unordered_map>& connections, + std::vector>& remoteMemory, std::vector& localMemory, std::vector& devicePtr, int numBuffers){ + + assert((deviceBufferSize / sizeof(int)) % worldSize == 0); + size_t dataCount = deviceBufferSize / sizeof(int); + + device_buffer_init(rank, worldSize, dataCount, devicePtr); + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "CUDA memory initialization passed" << std::endl; + + for (int n = 0; n < numBuffers; n++){ + write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize); + } + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "RDMA write for " << std::to_string(numBuffers) << " buffers passed" << std::endl; + + // polling until it becomes ready + bool ready = false; + int niter = 0; + do { + ready = test_device_buffer_write_correctness(worldSize, dataCount, devicePtr); + niter++; + if (niter == 10000){ + throw std::runtime_error("Polling is stuck."); + } + } while (!ready); + + bootstrap->barrier(); + if (bootstrap->getRank() == 
0) + std::cout << "Polling for " << std::to_string(numBuffers) << " buffers passed" << std::endl; +} + +__global__ void increament_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int worldSize){ + int tid = threadIdx.x; + if (tid != rank && tid < worldSize){ + deviceEpochs[tid].epochIncrement(); + } +} + +__global__ void wait_epochs(mscclpp::DeviceEpoch* deviceEpochs, int rank, int worldSize){ + int tid = threadIdx.x; + if (tid != rank && tid < worldSize){ + deviceEpochs[tid].wait(); + } +} + +void test_write_with_epochs(int rank, int worldSize, int deviceBufferSize, std::shared_ptr bootstrap, std::unordered_map>& connections, + std::vector>& remoteMemory, std::vector& localMemory, std::vector& devicePtr, std::unordered_map> epochs, int numBuffers){ + + assert((deviceBufferSize / sizeof(int)) % worldSize == 0); + size_t dataCount = deviceBufferSize / sizeof(int); + + device_buffer_init(rank, worldSize, dataCount, devicePtr); + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "CUDA memory initialization passed" << std::endl; + + mscclpp::DeviceEpoch* deviceEpochs; + CUDATHROW(cudaMalloc(&deviceEpochs, sizeof(mscclpp::DeviceEpoch) * worldSize)); + for (int i = 0; i < worldSize; i++){ + if (i != rank){ + mscclpp::DeviceEpoch deviceEpoch = epochs[i]->deviceEpoch(); + CUDATHROW(cudaMemcpy(&deviceEpochs[i], &deviceEpoch, sizeof(mscclpp::DeviceEpoch), cudaMemcpyHostToDevice)); + } + } + CUDATHROW(cudaDeviceSynchronize()); + + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "CUDA device epochs are created" << std::endl; + + + for (int n = 0; n < numBuffers; n++){ + write_remote(rank, worldSize, connections, remoteMemory[n], localMemory[n], dataCount / worldSize); + } + + increament_epochs<<<1, worldSize>>>(deviceEpochs, rank, worldSize); + CUDATHROW(cudaDeviceSynchronize()); + + for (int i = 0; i < worldSize; i++){ + if (i != rank){ + epochs[i]->signal(); + } + } + + wait_epochs<<<1, worldSize>>>(deviceEpochs, rank, 
worldSize); + CUDATHROW(cudaDeviceSynchronize()); + + if (!test_device_buffer_write_correctness(worldSize, dataCount, devicePtr)){ + throw std::runtime_error("unexpected result."); + } + + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "--- Testing writes with singal for " << std::to_string(numBuffers) << " buffers passed ---" << std::endl; +} + +void test_communicator(int rank, int worldSize, int nranksPerNode) +{ + auto bootstrap = std::make_shared(rank, worldSize); + mscclpp::UniqueId id; + if (bootstrap->getRank() == 0) + id = bootstrap->createUniqueId(); + MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD); + bootstrap->initialize(id); + + mscclpp::Communicator communicator(bootstrap); + if (bootstrap->getRank() == 0) + std::cout << "Communicator initialization passed" << std::endl; + + std::unordered_map> connections; + auto myIbDevice = findIb(rank % nranksPerNode); + + make_connections(communicator, rank, worldSize, nranksPerNode, myIbDevice, connections); + if (bootstrap->getRank() == 0) + std::cout << "Connection setup passed" << std::endl; + + int numBuffers = 10; + std::vector devicePtr(numBuffers); + int deviceBufferSize = 1024*1024; + + std::vector localMemory(numBuffers); + std::vector> remoteMemory(numBuffers); + + for (int n = 0; n < numBuffers; n++) { + if (n % 100 == 0) + std::cout << "Registering memory for " << std::to_string(n) << " buffers" << std::endl; + CUDATHROW(cudaMalloc(&devicePtr[n], deviceBufferSize)); + register_all_memories(communicator, rank, worldSize, devicePtr[n], deviceBufferSize, myIbDevice, localMemory[n], remoteMemory[n]); + } + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "Memory registration for " << std::to_string(numBuffers) << " buffers passed" << std::endl; + + test_write(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr, numBuffers); + if (bootstrap->getRank() == 0) + std::cout << "--- Testing vanialla writes passed 
---" << std::endl; + + std::unordered_map> epochs; + for (auto entry : connections) { + auto& conn = entry.second; + epochs.insert({entry.first, std::make_shared(communicator, conn)}); + } + communicator.setup(); + bootstrap->barrier(); + if (bootstrap->getRank() == 0) + std::cout << "Epochs are created" << std::endl; + + test_write_with_epochs(rank, worldSize, deviceBufferSize, bootstrap, connections, remoteMemory, localMemory, devicePtr, epochs, numBuffers); + + if (bootstrap->getRank() == 0) + std::cout << "--- MSCCLPP::Communicator tests passed! ---" << std::endl; + + for (int n = 0; n < numBuffers; n++){ + CUDATHROW(cudaFree(devicePtr[n])); + } +} + +int main(int argc, char** argv) +{ + int rank, worldSize; + MPI_Init(&argc, &argv); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &worldSize); + MPI_Comm shmcomm; + MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &shmcomm); + int shmWorldSize; + MPI_Comm_size(shmcomm, &shmWorldSize); + int nranksPerNode = shmWorldSize; + MPI_Comm_free(&shmcomm); + + test_communicator(rank, worldSize, nranksPerNode); + + MPI_Finalize(); + return 0; +} \ No newline at end of file From 81e7d1b344af413a5b6a73c0ebab754d0dff7bf6 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Wed, 3 May 2023 17:11:25 +0000 Subject: [PATCH 45/54] Channels work --- Makefile | 2 +- src/channel.cc | 26 ++++++++++++++++++++ src/connection.cc | 4 ++- src/include/channel.hpp | 6 ++++- src/include/proxy.hpp | 3 +-- src/include/utils.hpp | 54 +++++++++++++++++++++++++++++++++++++++++ src/proxy_cpp.cc | 16 ++++++++---- 7 files changed, 101 insertions(+), 10 deletions(-) create mode 100644 src/channel.cc create mode 100644 src/include/utils.hpp diff --git a/Makefile b/Makefile index 78b993cf0..2b80afb5d 100644 --- a/Makefile +++ b/Makefile @@ -121,7 +121,7 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc) LIBSRCS += 
$(addprefix src/bootstrap/,bootstrap.cc socket.cc) LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc) -LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc) +LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc channel.cc) ifneq ($(NPKIT), 0) LIBSRCS += $(addprefix src/misc/,npkit.cc) endif diff --git a/src/channel.cc b/src/channel.cc new file mode 100644 index 000000000..42572390a --- /dev/null +++ b/src/channel.cc @@ -0,0 +1,26 @@ +#include "channel.hpp" +#include "utils.h" +#include "checks.hpp" +#include "api.h" +#include "debug.h" + +namespace mscclpp { +namespace channel { + +MSCCLPP_API_CPP DeviceChannelService::DeviceChannelService(Communicator& communicator) : communicator_(communicator), + proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) { + int cudaDevice; + CUDATHROW(cudaGetDevice(&cudaDevice)); + MSCCLPPTHROW(getDeviceNumaNode(cudaDevice, &deviceNumaNode)); +} + +MSCCLPP_API_CPP void DeviceChannelService::bindThread() +{ + if (deviceNumaNode >= 0) { + MSCCLPPTHROW(numaBind(deviceNumaNode)); + INFO(MSCCLPP_INIT, "NUMA node of DeviceChannelService proxy thread is set to %d", deviceNumaNode); + } +} + +} // namespace channel +} // namespace mscclpp \ No newline at end of file diff --git a/src/connection.cc b/src/connection.cc index 66c54f062..0dee770b4 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -4,6 +4,7 @@ #include "infiniband/verbs.h" #include "npkit/npkit.h" #include "registered_memory.hpp" +#include "utils.hpp" namespace mscclpp { @@ -33,7 +34,7 @@ int ConnectionBase::tag() { return tag_; } CudaIpcConnection::CudaIpcConnection(int remoteRank, int tag) : ConnectionBase(remoteRank, tag) { - cudaStreamCreate(&stream); + CUDATHROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); } CudaIpcConnection::~CudaIpcConnection() @@ -54,6 +55,7 @@ Transport CudaIpcConnection::remoteTransport() void CudaIpcConnection::write(RegisteredMemory dst, 
uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { + ScopedTimer timer("CudaIpcConnection::write"); validateTransport(dst, remoteTransport()); validateTransport(src, transport()); diff --git a/src/include/channel.hpp b/src/include/channel.hpp index 42826f4f8..eb4bd9e7e 100644 --- a/src/include/channel.hpp +++ b/src/include/channel.hpp @@ -5,6 +5,7 @@ #include "mscclpp.hpp" #include "proxy.hpp" #include "mscclppfifo.hpp" +#include "utils.hpp" namespace mscclpp { namespace channel { @@ -177,7 +178,7 @@ inline ProxyHandler makeChannelProxyHandler(DeviceChannelService& channelService class DeviceChannelService { public: - DeviceChannelService(Communicator& communicator) : communicator_(communicator), proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }) {} + DeviceChannelService(Communicator& communicator); ChannelId addChannel(std::shared_ptr connection) { channels_.push_back(Channel(communicator_, connection)); @@ -200,6 +201,9 @@ class DeviceChannelService { std::vector channels_; std::vector memories_; Proxy proxy_; + int deviceNumaNode; + + void bindThread(); ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw) { ChannelTrigger* trigger = reinterpret_cast(&triggerRaw); diff --git a/src/include/proxy.hpp b/src/include/proxy.hpp index f913beac7..51ae47525 100644 --- a/src/include/proxy.hpp +++ b/src/include/proxy.hpp @@ -21,12 +21,11 @@ using ProxyHandler = std::function; class Proxy { public: + Proxy(ProxyHandler handler, std::function threadInit); Proxy(ProxyHandler handler); - ~Proxy(); void start(); - void stop(); HostProxyFifo& fifo(); diff --git a/src/include/utils.hpp b/src/include/utils.hpp new file mode 100644 index 000000000..9abf99944 --- /dev/null +++ b/src/include/utils.hpp @@ -0,0 +1,54 @@ +#ifndef MSCCLPP_UTILS_HPP_ +#define MSCCLPP_UTILS_HPP_ + +#include +#include + +namespace mscclpp { + +struct Timer +{ + std::chrono::steady_clock::time_point start; + + Timer() + { + start = 
std::chrono::steady_clock::now(); + } + + int64_t elapsed() + { + auto end = std::chrono::steady_clock::now(); + return std::chrono::duration_cast(end - start).count(); + } + + void reset() + { + start = std::chrono::steady_clock::now(); + } + + void print(const char* name) + { + auto end = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast(end - start).count(); + printf("%s: %ld us\n", name, elapsed); + } +}; + +struct ScopedTimer +{ + Timer timer; + const char* name; + + ScopedTimer(const char* name) : name(name) + { + } + + ~ScopedTimer() + { + timer.print(name); + } +}; + +} // namespace mscclpp + +#endif // MSCCLPP_UTILS_HPP_ diff --git a/src/proxy_cpp.cc b/src/proxy_cpp.cc index 2fb8c2b0d..b16268130 100644 --- a/src/proxy_cpp.cc +++ b/src/proxy_cpp.cc @@ -2,6 +2,7 @@ #include "api.h" #include "mscclpp.hpp" #include "utils.h" +#include "utils.hpp" #include #include @@ -14,18 +15,23 @@ const int ProxyFlushPeriod = 4; struct Proxy::Impl { ProxyHandler handler; + std::function threadInit; HostProxyFifo fifo; std::thread service; std::atomic_bool running; - Impl(ProxyHandler handler) : handler(handler), running(false) + Impl(ProxyHandler handler, std::function threadInit) : handler(handler), threadInit(threadInit), running(false) { } }; -MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) +MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function threadInit) +{ + pimpl = std::make_unique(handler, threadInit); +} + +MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) : Proxy(handler, [] {}) { - pimpl = std::make_unique(handler); } MSCCLPP_API_CPP Proxy::~Proxy() @@ -39,8 +45,8 @@ MSCCLPP_API_CPP void Proxy::start() { pimpl->running = true; pimpl->service = std::thread([this] { - // from this point on, proxy thread will stay close to the device - // PROXYMSCCLPPCHECK(numaBind(pimpl->comm->devNumaNode)); // TODO: reenable this + + pimpl->threadInit(); ProxyHandler handler = this->pimpl->handler; HostProxyFifo& fifo = this->pimpl->fifo; 
From 39666f999ffdbc1d5a1fa6cdaff7494548808fd6 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Wed, 3 May 2023 19:20:45 +0000 Subject: [PATCH 46/54] Quick fix --- src/communicator.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/communicator.cc b/src/communicator.cc index 1fd641320..469502b79 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -86,7 +86,9 @@ struct MemoryReceiver : public Setuppable { std::vector data; bootstrap->recv(data, remoteRank_, tag_); - memoryPromise_.set_value(RegisteredMemory::deserialize(data)); + auto memory = RegisteredMemory::deserialize(data); + memory.data(); + memoryPromise_.set_value(memory); } std::promise memoryPromise_; From 4a41c19e721ffad3e22fbf23611cb217548a93da Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Wed, 3 May 2023 19:40:23 +0000 Subject: [PATCH 47/54] Fix performance bug and base pointer offset --- src/communicator.cc | 4 +-- src/connection.cc | 1 - src/include/registered_memory.hpp | 7 ++-- src/registered_memory.cc | 57 ++++++++++++++++--------------- src/utils.cc | 9 ++++- 5 files changed, 44 insertions(+), 34 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index 469502b79..1fd641320 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -86,9 +86,7 @@ struct MemoryReceiver : public Setuppable { std::vector data; bootstrap->recv(data, remoteRank_, tag_); - auto memory = RegisteredMemory::deserialize(data); - memory.data(); - memoryPromise_.set_value(memory); + memoryPromise_.set_value(RegisteredMemory::deserialize(data)); } std::promise memoryPromise_; diff --git a/src/connection.cc b/src/connection.cc index 0dee770b4..dca3e6629 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -55,7 +55,6 @@ Transport CudaIpcConnection::remoteTransport() void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { - ScopedTimer timer("CudaIpcConnection::write"); 
validateTransport(dst, remoteTransport()); validateTransport(src, transport()); diff --git a/src/include/registered_memory.hpp b/src/include/registered_memory.hpp index e95507f1f..bf4802ce6 100644 --- a/src/include/registered_memory.hpp +++ b/src/include/registered_memory.hpp @@ -16,7 +16,10 @@ struct TransportInfo // TODO: rewrite this using std::variant or something bool ibLocal; union { - cudaIpcMemHandle_t cudaIpcHandle; + struct { + cudaIpcMemHandle_t cudaIpcBaseHandle; + size_t cudaIpcOffsetFromBase; + }; struct { const IbMr* ibMr; IbMrInfo ibMrInfo; @@ -27,9 +30,9 @@ struct TransportInfo struct RegisteredMemory::Impl { void* data; - bool dataInitialized; size_t size; int rank; + uint64_t hostHash; TransportFlags transports; std::vector transportInfos; diff --git a/src/registered_memory.cc b/src/registered_memory.cc index abf17a8bc..fed732a02 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -1,13 +1,14 @@ #include "registered_memory.hpp" #include "api.h" #include "checks.hpp" +#include "utils.h" #include #include namespace mscclpp { RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl) - : data(data), dataInitialized(true), size(size), rank(rank), transports(transports) + : data(data), size(size), rank(rank), hostHash(commImpl.rankToHash_.at(rank)), transports(transports) { if (transports.has(Transport::CudaIpc)) { TransportInfo transportInfo; @@ -18,7 +19,9 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags t size_t baseDataSize; // dummy CUTHROW(cuMemGetAddressRange((CUdeviceptr*)&baseDataPtr, &baseDataSize, (CUdeviceptr)data)); CUDATHROW(cudaIpcGetMemHandle(&handle, baseDataPtr)); - transportInfo.cudaIpcHandle = handle; + // TODO: bug with offset of base? 
+ transportInfo.cudaIpcBaseHandle = handle; + transportInfo.cudaIpcOffsetFromBase = (char*)data - (char*)baseDataPtr; this->transportInfos.push_back(transportInfo); } if ((transports & AllIBTransports).any()) { @@ -57,24 +60,12 @@ MSCCLPP_API_CPP RegisteredMemory::RegisteredMemory(std::shared_ptr pimpl) MSCCLPP_API_CPP RegisteredMemory::~RegisteredMemory() = default; -void* RegisteredMemory::data() +MSCCLPP_API_CPP void* RegisteredMemory::data() { - if (!pimpl->dataInitialized) { - if (pimpl->transports.has(Transport::CudaIpc)) { - auto entry = pimpl->getTransportInfo(Transport::CudaIpc); - CUDATHROW(cudaIpcOpenMemHandle(&pimpl->data, entry.cudaIpcHandle, cudaIpcMemLazyEnablePeerAccess)); - INFO(MSCCLPP_P2P, "Opened CUDA IPC handle for base point of %p", pimpl->data); - } - else - { - pimpl->data = nullptr; - } - pimpl->dataInitialized = true; - } return pimpl->data; } -size_t RegisteredMemory::size() +MSCCLPP_API_CPP size_t RegisteredMemory::size() { return pimpl->size; } @@ -84,7 +75,7 @@ MSCCLPP_API_CPP int RegisteredMemory::rank() return pimpl->rank; } -TransportFlags RegisteredMemory::transports() +MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return pimpl->transports; } @@ -94,6 +85,7 @@ MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() std::vector result; std::copy_n(reinterpret_cast(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl->rank), sizeof(pimpl->rank), std::back_inserter(result)); + std::copy_n(reinterpret_cast(&pimpl->hostHash), sizeof(pimpl->hostHash), std::back_inserter(result)); std::copy_n(reinterpret_cast(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result)); if (pimpl->transportInfos.size() > std::numeric_limits::max()) { throw std::runtime_error("Too many transport info entries"); @@ -103,7 +95,9 @@ MSCCLPP_API_CPP std::vector RegisteredMemory::serialize() for (auto& entry : pimpl->transportInfos) { 
std::copy_n(reinterpret_cast(&entry.transport), sizeof(entry.transport), std::back_inserter(result)); if (entry.transport == Transport::CudaIpc) { - std::copy_n(reinterpret_cast(&entry.cudaIpcHandle), sizeof(entry.cudaIpcHandle), + std::copy_n(reinterpret_cast(&entry.cudaIpcBaseHandle), sizeof(entry.cudaIpcBaseHandle), + std::back_inserter(result)); + std::copy_n(reinterpret_cast(&entry.cudaIpcOffsetFromBase), sizeof(entry.cudaIpcOffsetFromBase), std::back_inserter(result)); } else if (AllIBTransports.has(entry.transport)) { std::copy_n(reinterpret_cast(&entry.ibMrInfo), sizeof(entry.ibMrInfo), std::back_inserter(result)); @@ -126,6 +120,8 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) it += sizeof(this->size); std::copy_n(it, sizeof(this->rank), reinterpret_cast(&this->rank)); it += sizeof(this->rank); + std::copy_n(it, sizeof(this->hostHash), reinterpret_cast(&this->hostHash)); + it += sizeof(this->hostHash); std::copy_n(it, sizeof(this->transports), reinterpret_cast(&this->transports)); it += sizeof(this->transports); int8_t transportCount; @@ -136,15 +132,13 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) std::copy_n(it, sizeof(transportInfo.transport), reinterpret_cast(&transportInfo.transport)); it += sizeof(transportInfo.transport); if (transportInfo.transport == Transport::CudaIpc) { - cudaIpcMemHandle_t handle; - std::copy_n(it, sizeof(handle), reinterpret_cast(&handle)); - it += sizeof(handle); - transportInfo.cudaIpcHandle = handle; + std::copy_n(it, sizeof(transportInfo.cudaIpcBaseHandle), reinterpret_cast(&transportInfo.cudaIpcBaseHandle)); + it += sizeof(transportInfo.cudaIpcBaseHandle); + std::copy_n(it, sizeof(transportInfo.cudaIpcOffsetFromBase), reinterpret_cast(&transportInfo.cudaIpcOffsetFromBase)); + it += sizeof(transportInfo.cudaIpcOffsetFromBase); } else if (AllIBTransports.has(transportInfo.transport)) { - IbMrInfo info; - std::copy_n(it, sizeof(info), reinterpret_cast(&info)); - it += sizeof(info); - 
transportInfo.ibMrInfo = info; + std::copy_n(it, sizeof(transportInfo.ibMrInfo), reinterpret_cast(&transportInfo.ibMrInfo)); + it += sizeof(transportInfo.ibMrInfo); transportInfo.ibLocal = false; } else { throw std::runtime_error("Unknown transport"); @@ -155,7 +149,16 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) throw std::runtime_error("Deserialization failed"); } - dataInitialized = false; + if (transports.has(Transport::CudaIpc)) { + uint64_t localHostHash = getHostHash(); + if (localHostHash == this->hostHash) { + auto entry = getTransportInfo(Transport::CudaIpc); + void* base; + CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess)); + data = static_cast(base) + entry.cudaIpcOffsetFromBase; + INFO(MSCCLPP_P2P, "Opened CUDA IPC handle at pointer %p", data); + } + } } } // namespace mscclpp diff --git a/src/utils.cc b/src/utils.cc index ebd31bfed..6954a64fc 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -9,6 +9,7 @@ #include #include #include +#include // Get current Compute Capability // int mscclppCudaCompCap() { @@ -112,7 +113,7 @@ uint64_t getHash(const char* string, int n) * This string can be overridden by using the MSCCLPP_HOSTID env var. 
*/ #define HOSTID_FILE "/proc/sys/kernel/random/boot_id" -uint64_t getHostHash(void) +uint64_t computeHostHash(void) { char hostHash[1024]; char* hostId; @@ -144,6 +145,12 @@ uint64_t getHostHash(void) return getHash(hostHash, strlen(hostHash)); } +uint64_t getHostHash(void) +{ + thread_local std::unique_ptr hostHash = std::make_unique(computeHostHash()); + return *hostHash; +} + /* Generate a hash of the unique identifying string for this process * that will be unique for both bare-metal and container instances * Equivalent of a hash of; From 7af687954c0b7f2efbdd3ed87d76bac1d12f753f Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Wed, 3 May 2023 20:23:51 +0000 Subject: [PATCH 48/54] removing old mscclppComm_t comm from communicator --- src/include/communicator.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/include/communicator.hpp b/src/include/communicator.hpp index 32fb6e302..5b0c7485d 100644 --- a/src/include/communicator.hpp +++ b/src/include/communicator.hpp @@ -14,7 +14,6 @@ class ConnectionBase; struct Communicator::Impl { - mscclppComm_t comm; std::vector> connections_; std::vector> toSetup_; std::unordered_map> ibContexts_; From 518f325225ccece587aba8a0874a277d22cb4cc8 Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Wed, 3 May 2023 22:45:47 +0000 Subject: [PATCH 49/54] kernel 2 is also performant --- tests/allgather_test_cpp.cu | 14 +++++++++++--- tests/communicator_test_cpp.cu | 17 ----------------- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/tests/allgather_test_cpp.cu b/tests/allgather_test_cpp.cu index aaff931c7..ad473f8f4 100644 --- a/tests/allgather_test_cpp.cu +++ b/tests/allgather_test_cpp.cu @@ -84,7 +84,7 @@ __device__ void localAllGather(mscclpp::channel::SimpleDeviceChannel devChan, in if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) { // put your data to GPU (rank+i) % nranksPerNode and signal in one call if ((threadIdx.x % 32) == 0) - devChan.putWithSignalAndFlush(offset, size); + 
devChan.putWithSignal(offset, size); } // wait for the data from GPU (rank-i) % nranksPerNode to arrive if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) { @@ -100,6 +100,9 @@ __device__ void allgather1(mscclpp::channel::SimpleDeviceChannel devChan, int ra { localAllGather(devChan, rank, world_size, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int)); + if (remoteRank / nranksPerNode == rank / nranksPerNode) + if ((threadIdx.x % 32) == 0) + devChan.flush(); } __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int rank, int world_size, int nranksPerNode, @@ -127,7 +130,7 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra if (remoteRank % nranksPerNode == rank % nranksPerNode) { // opposite side if ((threadIdx.x % 32) == 0) - devChan.putWithSignalAndFlush(rank * nelemsPerGPU * sizeof(int), + devChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int)); if ((threadIdx.x % 32) == 0) devChan.wait(); @@ -147,7 +150,7 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra if (remoteRank % nranksPerNode == rank % nranksPerNode) { // opposite side if ((threadIdx.x % 32) == 0) - devChan.putWithSignalAndFlush((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * + devChan.putWithSignal((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int), nelemsPerGPU / pipelineSize * sizeof(int)); if ((threadIdx.x % 32) == 0) @@ -163,6 +166,11 @@ __device__ void allgather2(mscclpp::channel::SimpleDeviceChannel devChan, int ra (otherNghr * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int), nelemsPerGPU / pipelineSize * sizeof(int)); } + + if (remoteRank / nranksPerNode == rank / nranksPerNode || remoteRank % nranksPerNode == rank % nranksPerNode) { + if ((threadIdx.x % 32) == 0) + 
devChan.flush(); + } } __global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelemsPerGPU, int kernel) diff --git a/tests/communicator_test_cpp.cu b/tests/communicator_test_cpp.cu index fcdd0f5a3..56c8592e8 100644 --- a/tests/communicator_test_cpp.cu +++ b/tests/communicator_test_cpp.cu @@ -39,23 +39,6 @@ void register_all_memories(mscclpp::Communicator& communicator, int rank, int wo remoteMemory[i] = futureRemoteMemory[i].get(); } } - - - // auto serialized = localMemory.serialize(); - // int serializedSize = serialized.size(); - // for (int i = 0; i < worldSize; i++) { - // if (i != rank){ - // communicator.bootstrapper()->send(serialized.data(), serializedSize, i, 0); - // } - // } - // for (int i = 0; i < worldSize; i++) { - // if (i != rank){ - // std::vector deserialized(serializedSize); - // communicator.bootstrapper()->recv(deserialized.data(), serializedSize, i, 0); - // auto remote = mscclpp::RegisteredMemory::deserialize(deserialized); - // remoteMemory[i] = remote; - // } - // } } void make_connections(mscclpp::Communicator& communicator, int rank, int worldSize, int nRanksPerNode, mscclpp::Transport myIbDevice, std::unordered_map>& connections){ From 503cdd5c7ee693e6b16a1aa75fa85d86321bf0f0 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Thu, 13 Apr 2023 00:23:25 +0000 Subject: [PATCH 50/54] CMake build system transition WIP --- CMakeLists.txt | 23 ++++++++++++++++++ cmake/modules/FindGDRCopy.cmake | 41 +++++++++++++++++++++++++++++++++ cmake/modules/FindIBVerbs.cmake | 41 +++++++++++++++++++++++++++++++++ cmake/modules/FindNUMA.cmake | 41 +++++++++++++++++++++++++++++++++ src/CMakeLists.txt | 12 ++++++++++ tests/CMakeLists.txt | 5 ++++ 6 files changed, 163 insertions(+) create mode 100644 CMakeLists.txt create mode 100644 cmake/modules/FindGDRCopy.cmake create mode 100644 cmake/modules/FindIBVerbs.cmake create mode 100644 cmake/modules/FindNUMA.cmake create mode 100644 src/CMakeLists.txt create mode 100644 
tests/CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..87c9c24ee --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,23 @@ +cmake_minimum_required(VERSION 3.26) + +project(mscclpp LANGUAGES CUDA CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) + +list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/modules) + +find_package(CUDAToolkit REQUIRED) +find_package(IBVerbs REQUIRED) +find_package(NUMA REQUIRED) +find_package(GDRCopy) + +option(USE_MPI_FOR_TESTS "Use MPI for tests" ON) +if(USE_MPI_FOR_TESTS) + find_package(MPI REQUIRED) +endif() + +include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + +add_subdirectory(src) +add_subdirectory(tests) \ No newline at end of file diff --git a/cmake/modules/FindGDRCopy.cmake b/cmake/modules/FindGDRCopy.cmake new file mode 100644 index 000000000..cde447bad --- /dev/null +++ b/cmake/modules/FindGDRCopy.cmake @@ -0,0 +1,41 @@ +# Find the GDRCopy libraries +# +# The following variables are optionally searched for defaults +# GDRCOPY_ROOT_DIR: Base directory where all GDRCopy components are found +# GDRCOPY_INCLUDE_DIR: Directory where GDRCopy headers are found +# GDRCOPY_LIB_DIR: Directory where GDRCopy libraries are found + +# The following are set after configuration is done: +# GDRCOPY_FOUND +# GDRCOPY_INCLUDE_DIRS +# GDRCOPY_LIBRARIES + +# An imported target MSCCLPP::gdrcopy is created if the library is found. 
+ +find_path(GDRCOPY_INCLUDE_DIRS + NAMES gdrapi.h + HINTS + ${GDRCOPY_INCLUDE_DIR} + ${GDRCOPY_ROOT_DIR} + ${GDRCOPY_ROOT_DIR}/include) + +find_library(GDRCOPY_LIBRARIES + NAMES gdrapi + HINTS + ${GDRCOPY_LIB_DIR} + ${GDRCOPY_ROOT_DIR} + ${GDRCOPY_ROOT_DIR}/lib) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(GDRCopy DEFAULT_MSG GDRCOPY_INCLUDE_DIRS GDRCOPY_LIBRARIES) +mark_as_advanced(GDRCOPY_INCLUDE_DIR GDRCOPY_LIBRARIES) + +if(GDRCOPY_FOUND) + if(NOT TARGET MSCCLPP::gdrcopy) + add_library(MSCCLPP::gdrcopy UNKNOWN IMPORTED) + endif() + set_target_properties(MSCCLPP::gdrcopy PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${GDRCOPY_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "C" + IMPORTED_LOCATION "${GDRCOPY_LIBRARIES}") +endif() \ No newline at end of file diff --git a/cmake/modules/FindIBVerbs.cmake b/cmake/modules/FindIBVerbs.cmake new file mode 100644 index 000000000..fc80b11c0 --- /dev/null +++ b/cmake/modules/FindIBVerbs.cmake @@ -0,0 +1,41 @@ +# Find the IB Verbs libraries +# +# The following variables are optionally searched for defaults +# IBVERBS_ROOT_DIR: Base directory where all ibverbs components are found +# IBVERBS_INCLUDE_DIR: Directory where ibverbs headers are found +# IBVERBS_LIB_DIR: Directory where ibverbs libraries are found + +# The following are set after configuration is done: +# IBVERBS_FOUND +# IBVERBS_INCLUDE_DIRS +# IBVERBS_LIBRARIES + +# An imported target MSCCLPP::ibverbs is created if the library is found. 
+ +find_path(IBVERBS_INCLUDE_DIRS + NAMES infiniband/verbs.h + HINTS + ${IBVERBS_INCLUDE_DIR} + ${IBVERBS_ROOT_DIR} + ${IBVERBS_ROOT_DIR}/include) + +find_library(IBVERBS_LIBRARIES + NAMES ibverbs + HINTS + ${IBVERBS_LIB_DIR} + ${IBVERBS_ROOT_DIR} + ${IBVERBS_ROOT_DIR}/lib) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(IBVerbs DEFAULT_MSG IBVERBS_INCLUDE_DIRS IBVERBS_LIBRARIES) +mark_as_advanced(IBVERBS_INCLUDE_DIR IBVERBS_LIBRARIES) + +if(IBVERBS_FOUND) + if(NOT TARGET MSCCLPP::ibverbs) + add_library(MSCCLPP::ibverbs UNKNOWN IMPORTED) + endif() + set_target_properties(MSCCLPP::ibverbs PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${IBVERBS_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "C" + IMPORTED_LOCATION "${IBVERBS_LIBRARIES}") +endif() \ No newline at end of file diff --git a/cmake/modules/FindNUMA.cmake b/cmake/modules/FindNUMA.cmake new file mode 100644 index 000000000..70e04d536 --- /dev/null +++ b/cmake/modules/FindNUMA.cmake @@ -0,0 +1,41 @@ +# Find the numa libraries +# +# The following variables are optionally searched for defaults +# NUMA_ROOT_DIR: Base directory where all numa components are found +# NUMA_INCLUDE_DIR: Directory where numa headers are found +# NUMA_LIB_DIR: Directory where numa libraries are found + +# The following are set after configuration is done: +# NUMA_FOUND +# NUMA_INCLUDE_DIRS +# NUMA_LIBRARIES + +# An imported target MSCCLPP::numa is created if the library is found. 
+ +find_path(NUMA_INCLUDE_DIRS + NAMES numa.h + HINTS + ${NUMA_INCLUDE_DIR} + ${NUMA_ROOT_DIR} + ${NUMA_ROOT_DIR}/include) + +find_library(NUMA_LIBRARIES + NAMES numa + HINTS + ${NUMA_LIB_DIR} + ${NUMA_ROOT_DIR} + ${NUMA_ROOT_DIR}/lib) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NUMA DEFAULT_MSG NUMA_INCLUDE_DIRS NUMA_LIBRARIES) +mark_as_advanced(NUMA_INCLUDE_DIR NUMA_LIBRARIES) + +if(NUMA_FOUND) + if(NOT TARGET MSCCLPP::numa) + add_library(MSCCLPP::numa UNKNOWN IMPORTED) + endif() + set_target_properties(MSCCLPP::numa PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${NUMA_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "C" + IMPORTED_LOCATION "${NUMA_LIBRARIES}") +endif() \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 000000000..f6bf1bc32 --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,12 @@ +file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc *.h) +file(GLOB to_remove gdr.cc) +list(REMOVE_ITEM SOURCES ${to_remove}) + +add_library(mscclpp SHARED ${SOURCES}) +set_target_properties(mscclpp PROPERTIES LINKER_LANGUAGE CXX) +target_link_libraries(mscclpp PRIVATE MSCCLPP::ibverbs MSCCLPP::numa CUDA::cudart) +if(GDRCOPY_FOUND) + target_link_libraries(mscclpp PRIVATE MSCCLPP::gdrcopy) +endif() + +target_include_directories(mscclpp PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 000000000..669c669df --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,5 @@ +add_executable(bootstrap_test bootstrap_test.cc) +target_link_libraries(bootstrap_test mscclpp) + +add_executable(allgather_test_standalone allgather_test_standalone.cu) +target_link_libraries(allgather_test_standalone mscclpp) \ No newline at end of file From 09d5f7c12ec6b0487d415ad10aa8eac086c58f58 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Thu, 4 May 2023 00:39:30 +0000 Subject: [PATCH 51/54] Fixes for cmake --- 
CMakeLists.txt | 1 + src/CMakeLists.txt | 2 +- tests/CMakeLists.txt | 7 +++++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 87c9c24ee..68fa1b84b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,6 +15,7 @@ find_package(GDRCopy) option(USE_MPI_FOR_TESTS "Use MPI for tests" ON) if(USE_MPI_FOR_TESTS) find_package(MPI REQUIRED) + add_definitions(-DMSCCLPP_USE_MPI_FOR_TESTS) endif() include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f6bf1bc32..1d989c6b7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -4,7 +4,7 @@ list(REMOVE_ITEM SOURCES ${to_remove}) add_library(mscclpp SHARED ${SOURCES}) set_target_properties(mscclpp PROPERTIES LINKER_LANGUAGE CXX) -target_link_libraries(mscclpp PRIVATE MSCCLPP::ibverbs MSCCLPP::numa CUDA::cudart) +target_link_libraries(mscclpp PRIVATE MSCCLPP::ibverbs MSCCLPP::numa CUDA::cudart CUDA::cuda_driver) if(GDRCOPY_FOUND) target_link_libraries(mscclpp PRIVATE MSCCLPP::gdrcopy) endif() diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 669c669df..fd02e658d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,5 +1,8 @@ add_executable(bootstrap_test bootstrap_test.cc) -target_link_libraries(bootstrap_test mscclpp) +target_link_libraries(bootstrap_test mscclpp MPI::MPI_CXX) + +add_executable(allgather_test_cpp allgather_test_cpp.cu) +target_link_libraries(allgather_test_cpp mscclpp MPI::MPI_CXX) add_executable(allgather_test_standalone allgather_test_standalone.cu) -target_link_libraries(allgather_test_standalone mscclpp) \ No newline at end of file +target_link_libraries(allgather_test_standalone mscclpp MPI::MPI_CXX) From bd2121a2efa53d3ab4104f53cae24c71981dda2d Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Thu, 4 May 2023 00:53:50 +0000 Subject: [PATCH 52/54] CMake improvement --- CMakeLists.txt | 12 ++++++++++-- src/CMakeLists.txt | 9 +-------- 2 files changed, 11 
insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 68fa1b84b..81f99cdb6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,5 +20,13 @@ endif() include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) -add_subdirectory(src) -add_subdirectory(tests) \ No newline at end of file +add_library(mscclpp SHARED) +add_subdirectory(src) # This adds the sources to the mscclpp target +target_include_directories(mscclpp PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/src/include) +set_target_properties(mscclpp PROPERTIES LINKER_LANGUAGE CXX) +target_link_libraries(mscclpp PRIVATE MSCCLPP::ibverbs MSCCLPP::numa CUDA::cudart CUDA::cuda_driver) +if(GDRCOPY_FOUND) + target_link_libraries(mscclpp PRIVATE MSCCLPP::gdrcopy) +endif() + +add_subdirectory(tests) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1d989c6b7..5e583d455 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -2,11 +2,4 @@ file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc *.h) file(GLOB to_remove gdr.cc) list(REMOVE_ITEM SOURCES ${to_remove}) -add_library(mscclpp SHARED ${SOURCES}) -set_target_properties(mscclpp PROPERTIES LINKER_LANGUAGE CXX) -target_link_libraries(mscclpp PRIVATE MSCCLPP::ibverbs MSCCLPP::numa CUDA::cudart CUDA::cuda_driver) -if(GDRCOPY_FOUND) - target_link_libraries(mscclpp PRIVATE MSCCLPP::gdrcopy) -endif() - -target_include_directories(mscclpp PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) \ No newline at end of file +target_sources(mscclpp PRIVATE ${SOURCES}) From d7103602acfa21723d022adc47d9dbc97057b80d Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Thu, 4 May 2023 00:55:35 +0000 Subject: [PATCH 53/54] Only build C++ tests in CMake --- tests/CMakeLists.txt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index fd02e658d..457003e3a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,8 +1,8 @@ -add_executable(bootstrap_test bootstrap_test.cc) 
-target_link_libraries(bootstrap_test mscclpp MPI::MPI_CXX) +add_executable(bootstrap_test_cpp bootstrap_test_cpp.cc) +target_link_libraries(bootstrap_test_cpp mscclpp MPI::MPI_CXX) + +add_executable(communicator_test_cpp communicator_test_cpp.cu) +target_link_libraries(communicator_test_cpp mscclpp MPI::MPI_CXX) add_executable(allgather_test_cpp allgather_test_cpp.cu) target_link_libraries(allgather_test_cpp mscclpp MPI::MPI_CXX) - -add_executable(allgather_test_standalone allgather_test_standalone.cu) -target_link_libraries(allgather_test_standalone mscclpp MPI::MPI_CXX) From ddc9e681c8428a13ebd5f4f24cd4f4d4af3c3e31 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Thu, 4 May 2023 00:57:34 +0000 Subject: [PATCH 54/54] Add ib_test to CMake --- tests/CMakeLists.txt | 2 ++ tests/unittests/CMakeLists.txt | 2 ++ 2 files changed, 4 insertions(+) create mode 100644 tests/unittests/CMakeLists.txt diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 457003e3a..b6ee63c7b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -6,3 +6,5 @@ target_link_libraries(communicator_test_cpp mscclpp MPI::MPI_CXX) add_executable(allgather_test_cpp allgather_test_cpp.cu) target_link_libraries(allgather_test_cpp mscclpp MPI::MPI_CXX) + +add_subdirectory(unittests) diff --git a/tests/unittests/CMakeLists.txt b/tests/unittests/CMakeLists.txt new file mode 100644 index 000000000..85f87f526 --- /dev/null +++ b/tests/unittests/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(ib_test ib_test.cc) +target_link_libraries(ib_test mscclpp MPI::MPI_CXX CUDA::cudart) \ No newline at end of file