diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7782daba1f..4aeabb97f4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -56,6 +56,16 @@ if(USE_XCCL)
     caffe2_update_option(USE_C10D_XCCL OFF)
     update_caffe2_macros_file()
   endif()
+  if(USE_ISHMEM)
+    include(${TORCH_XPU_OPS_ROOT}/cmake/ISHMEM.cmake)
+    if(NOT PYTORCH_FOUND_ISHMEM)
+      message(WARNING "ISHMEM not found, disabling ISHMEM support")
+      caffe2_update_option(USE_ISHMEM OFF)
+      update_caffe2_macros_file()
+    else()
+      message(STATUS "ISHMEM support enabled")
+    endif()
+  endif()
 endif()
 
 set(USE_SYCLTLA ON)
diff --git a/cmake/ISHMEM.cmake b/cmake/ISHMEM.cmake
new file mode 100644
index 0000000000..2319a71baa
--- /dev/null
+++ b/cmake/ISHMEM.cmake
@@ -0,0 +1,24 @@
+if(NOT __ISHMEM_INCLUDED)
+  set(__ISHMEM_INCLUDED TRUE)
+
+  # ISHMEM_ROOT, ISHMEM_LIBRARY_DIR, ISHMEM_INCLUDE_DIR are handled by FindISHMEM.cmake.
+  find_package(ISHMEM REQUIRED)
+  if(NOT ISHMEM_FOUND)
+    set(PYTORCH_FOUND_ISHMEM FALSE)
+    message(WARNING "${ISHMEM_NOT_FOUND_MESSAGE}")
+    return()
+  endif()
+
+  set(PYTORCH_FOUND_ISHMEM TRUE)
+  add_library(torch::ishmem INTERFACE IMPORTED)
+  set_property(
+    TARGET torch::ishmem PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+    ${ISHMEM_INCLUDE_DIR})
+  set_property(
+    TARGET torch::ishmem PROPERTY INTERFACE_LINK_LIBRARIES
+    ${ISHMEM_LIBRARY})
+
+  message(STATUS "Found Intel SHMEM: ${ISHMEM_ROOT}")
+  message(STATUS "  ISHMEM include dir: ${ISHMEM_INCLUDE_DIR}")
+  message(STATUS "  ISHMEM library: ${ISHMEM_LIBRARY}")
+endif()
diff --git a/cmake/Modules/FindISHMEM.cmake b/cmake/Modules/FindISHMEM.cmake
new file mode 100644
index 0000000000..96c7656b85
--- /dev/null
+++ b/cmake/Modules/FindISHMEM.cmake
@@ -0,0 +1,65 @@
+# This will define the following variables:
+# ISHMEM_FOUND        : True if the system has the ISHMEM library.
+# ISHMEM_INCLUDE_DIR  : Include directories needed to use ISHMEM.
+# ISHMEM_LIBRARY_DIR  : The path to the ISHMEM library.
+# ISHMEM_LIBRARY      : ISHMEM library fullname.
+
+include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake)
+
+if(NOT CMAKE_SYSTEM_NAME MATCHES "Linux")
+  set(ISHMEM_FOUND False)
+  set(ISHMEM_NOT_FOUND_MESSAGE "Intel SHMEM library is only supported on Linux!")
+  return()
+endif()
+
+set(ISHMEM_ROOT $ENV{ISHMEM_ROOT})
+
+if(NOT ISHMEM_ROOT)
+  set(ISHMEM_FOUND False)
+  set(ISHMEM_NOT_FOUND_MESSAGE "ISHMEM_ROOT environment variable not set. Please set it to your ISHMEM installation directory.")
+  return()
+endif()
+
+# Find include path from binary.
+find_path(
+  ISHMEM_INCLUDE_DIR
+  NAMES ishmem.h
+  HINTS ${ISHMEM_ROOT}/include
+  NO_DEFAULT_PATH
+)
+
+# Find library directory from binary.
+find_path(
+  ISHMEM_LIBRARY_DIR
+  NAMES libishmem.a
+  HINTS ${ISHMEM_ROOT}/lib
+  NO_DEFAULT_PATH
+)
+
+# Find ISHMEM library fullname (static library).
+find_library(
+  ISHMEM_LIBRARY
+  NAMES ishmem
+  HINTS ${ISHMEM_LIBRARY_DIR}
+  NO_DEFAULT_PATH
+)
+
+if((NOT ISHMEM_INCLUDE_DIR) OR (NOT ISHMEM_LIBRARY_DIR) OR (NOT ISHMEM_LIBRARY))
+  set(ISHMEM_FOUND False)
+  set(ISHMEM_NOT_FOUND_MESSAGE "Intel SHMEM library not found! Please set ISHMEM_ROOT environment variable.")
+  return()
+endif()
+
+SET(CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH}
+  "${ISHMEM_INCLUDE_DIR}")
+SET(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH}
+  "${ISHMEM_LIBRARY_DIR}")
+
+find_package_handle_standard_args(
+  ISHMEM
+  FOUND_VAR ISHMEM_FOUND
+  REQUIRED_VARS ISHMEM_INCLUDE_DIR ISHMEM_LIBRARY_DIR ISHMEM_LIBRARY
+  REASON_FAILURE_MESSAGE "${ISHMEM_NOT_FOUND_MESSAGE}"
+)
+
+mark_as_advanced(ISHMEM_INCLUDE_DIR ISHMEM_LIBRARY_DIR ISHMEM_LIBRARY)
diff --git a/src/BuildOnLinux.cmake b/src/BuildOnLinux.cmake
index 8a26c8d149..a775d95fa6 100644
--- a/src/BuildOnLinux.cmake
+++ b/src/BuildOnLinux.cmake
@@ -17,6 +17,9 @@ macro(setup_common_libraries)
     target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL)
     target_link_libraries(torch_xpu_ops PUBLIC torch::xccl)
     target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only)
+    if(USE_ISHMEM AND PYTORCH_FOUND_ISHMEM)
+      target_link_libraries(torch_xpu_ops PUBLIC torch::ishmem)
+    endif()
   endif()
 
   if(USE_SYCLTLA)
@@ -50,6 +53,9 @@ else()
     target_compile_definitions(torch_xpu_ops PRIVATE USE_C10D_XCCL)
     target_link_libraries(torch_xpu_ops PUBLIC torch::xccl)
     target_link_libraries(torch_xpu_ops PUBLIC fmt::fmt-header-only)
+    if(USE_ISHMEM AND PYTORCH_FOUND_ISHMEM)
+      target_link_libraries(torch_xpu_ops PUBLIC torch::ishmem)
+    endif()
   endif()
 
   if(USE_SYCLTLA)
diff --git a/src/xccl/ISHMEMSymmetricMemory.cpp b/src/xccl/ISHMEMSymmetricMemory.cpp
new file mode 100644
index 0000000000..989960d52d
--- /dev/null
+++ b/src/xccl/ISHMEMSymmetricMemory.cpp
@@ -0,0 +1,446 @@
+#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
+#include "XPUSymmetricMemoryUtils.hpp"
+
+#include <ATen/xpu/XPUContext.h>
+#include <c10/core/DeviceGuard.h>
+#include <c10/util/Logging.h>
+#include <c10/xpu/XPUCachingAllocator.h>
+// Include ISHMEM headers - directly link to static library
+#include <ishmem.h>
+#include <ishmemx.h>
+
+namespace c10d {
+namespace symmetric_memory {
+
+/* Start of ISHMEMSymmetricMemory implementation */
+
+// XPU-specific constants for symmetric memory
+// Intel Data Center GPU Max can support up to 8 GPUs in a single node
+constexpr int max_xpu_p2p_domain_size = 8;
+// Maximum number of channels (same as CUDA)
+constexpr int xpu_symm_max_nblocks = 32;
+// Signal pad size for XPU
+constexpr size_t xpu_signal_pad_size =
+    xpu_symm_max_nblocks * max_xpu_p2p_domain_size * sizeof(uint32_t);
+
+static StoreExchange storeExchange = StoreExchange("ISHMEMSymmetricMemory");
+
+// RAII wrapper around one symmetric heap allocation made via `ishmem_malloc`.
+struct ISHMEMAllocation {
+  void* ptr;
+  size_t buffer_size;
+  int device_idx;
+
+  ISHMEMAllocation(void* ptr, size_t buffer_size, int device_idx)
+      : ptr(ptr), buffer_size(buffer_size), device_idx(device_idx) {}
+
+  ~ISHMEMAllocation() {
+    // Avoid calling XPU functions after the driver has shut down
+    if (is_finalizing()) {
+      return;
+    }
+    c10::OptionalDeviceGuard guard;
+    guard.reset_device(at::Device(at::DeviceType::XPU, device_idx));
+    ishmem_free(ptr);
+  }
+};
+
+// A class to hold the base pointers and signal pad pointers for a group of
+// peers. One `ISHMEMPeerAllocInfo` object can be shared by multiple
+// `ISHMEMSymmetricMemory` objects when latter reside on the same allocation
+// and rendezvous over the same group. (The `ISHMEMSymmetricMemory` objects may
+// have different offsets compared to the base address.)
+class ISHMEMPeerAllocInfo : public c10::intrusive_ptr_target {
+ public:
+  ISHMEMPeerAllocInfo(
+      ISHMEMAllocation* allocation,
+      const std::string& group_name)
+      : base_ptr_(allocation->ptr), buffer_size_(allocation->buffer_size) {
+    // For logging only
+    static int exchanged_n_times = 0;
+
+    c10::OptionalDeviceGuard guard;
+    guard.reset_device(at::Device(at::DeviceType::XPU, allocation->device_idx));
+
+    auto global_rank = get_group_info("0").rank;
+    GroupInfo& group_info = get_group_info(group_name);
+    auto store = group_info.store;
+    rank_ = group_info.rank;
+    world_size_ = group_info.world_size;
+    // Exchange rank to global rank mapping for this group.
+    // If it is already available, skip the exchange.
+    if (group_info.rank_to_global_rank.empty()) {
+      group_info.rank_to_global_rank =
+          storeExchange.all_gather(store, rank_, world_size_, global_rank);
+      exchanged_n_times++;
+      if (rank_ == 0) {
+        LOG(INFO) << "[rank " << rank_ << ']'
+                  << " rank_to_global_rank: " << group_info.rank_to_global_rank
+                  << ", group_name: " << group_name
+                  << ", exchanged_n_times: " << exchanged_n_times;
+      }
+    }
+    TORCH_INTERNAL_ASSERT(!group_info.rank_to_global_rank.empty());
+    rank_to_global_rank_ = group_info.rank_to_global_rank;
+
+    // Collect peer base pointers; also detect whether every peer is reachable
+    // via direct P2P (ishmem_ptr returns nullptr for network-only peers).
+    world_within_xpu_p2p_ = true;
+    for (int r = 0; r < world_size_; ++r) {
+      auto peer_ptr = ishmem_ptr(base_ptr_, rank_to_global_rank_[r]);
+      buffers_.push_back(peer_ptr);
+      // If a peer is over network, `ishmem_ptr` returns null
+      if (peer_ptr == nullptr) {
+        world_within_xpu_p2p_ = false;
+      }
+    }
+
+    // TODO: use the same allocation for signal pad
+    void* signal_pad_ptr = ishmem_malloc(xpu_signal_pad_size);
+    TORCH_CHECK(signal_pad_ptr != nullptr, "ishmem_malloc failed");
+
+    // Use SYCL queue to initialize signal pad memory
+    auto& queue = at::xpu::getCurrentSYCLQueue();
+    queue.memset(signal_pad_ptr, 0, xpu_signal_pad_size).wait();
+
+    for (int r = 0; r < world_size_; ++r) {
+      signal_pads_.push_back(
+          ishmem_ptr(signal_pad_ptr, rank_to_global_rank_[r]));
+    }
+
+    // Mirror the host-side pointer tables into device memory so kernels can
+    // index peers directly.
+    const size_t arr_size = sizeof(void*) * world_size_;
+    buffers_dev_ = reinterpret_cast<void**>(
+        c10::xpu::XPUCachingAllocator::raw_alloc(arr_size));
+    signal_pads_dev_ = reinterpret_cast<void**>(
+        c10::xpu::XPUCachingAllocator::raw_alloc(arr_size));
+
+    queue.memcpy(buffers_dev_, buffers_.data(), arr_size).wait();
+    queue.memcpy(signal_pads_dev_, signal_pads_.data(), arr_size).wait();
+
+    rank_to_global_rank_dev_ = reinterpret_cast<int*>(
+        c10::xpu::XPUCachingAllocator::raw_alloc(sizeof(int) * world_size_));
+    queue
+        .memcpy(
+            rank_to_global_rank_dev_,
+            rank_to_global_rank_.data(),
+            sizeof(int) * world_size_)
+        .wait();
+  }
+
+ private:
+  void* base_ptr_;
+  size_t buffer_size_;
+  int rank_;
+  int world_size_;
+  std::vector<void*> buffers_;
+  std::vector<void*> signal_pads_;
+  void** buffers_dev_;
+  void** signal_pads_dev_;
+  std::vector<int> rank_to_global_rank_;
+  int* rank_to_global_rank_dev_;
+  // Whether the world is within XPU P2P only, not network
+  bool world_within_xpu_p2p_;
+
+  friend class ISHMEMSymmetricMemory;
+};
+
+class ISHMEMSymmetricMemory : public SymmetricMemory {
+ public:
+  ISHMEMSymmetricMemory(
+      ISHMEMAllocation* allocation,
+      const std::string& group_name)
+      : device_idx_(allocation->device_idx), group_name_(group_name) {
+    // A handle stores two types of info:
+    // (i) allocation's base ptrs and base signal pads, ours and peers'
+    pai_ = c10::make_intrusive<ISHMEMPeerAllocInfo>(allocation, group_name);
+    // (ii) offset of tensor compared to base ptr (in byte)
+    offset_ = 0;
+  }
+
+  // Exact copy is not needed / supported
+  ISHMEMSymmetricMemory(const ISHMEMSymmetricMemory& other) = delete;
+
+  // Copy with offset is allowed
+  // This is mostly a shallow copy that shares the pointer to
+  // `ISHMEMPeerAllocInfo` which has been created by `other`
+  ISHMEMSymmetricMemory(const ISHMEMSymmetricMemory& other, size_t offset)
+      : device_idx_(other.device_idx_),
+        group_name_(other.group_name_),
+        pai_(other.pai_) {
+    offset_ = offset;
+  }
+
+  ~ISHMEMSymmetricMemory() override {
+    // TODO
+  }
+
+  std::vector<void*> get_buffer_ptrs() override {
+    return pai_->buffers_;
+  }
+
+  std::vector<void*> get_signal_pad_ptrs() override {
+    return pai_->signal_pads_;
+  }
+
+  void** get_buffer_ptrs_dev() override {
+    return pai_->buffers_dev_;
+  }
+
+  void** get_signal_pad_ptrs_dev() override {
+    return pai_->signal_pads_dev_;
+  }
+
+  size_t get_buffer_size() override {
+    return pai_->buffer_size_;
+  }
+
+  size_t get_signal_pad_size() override {
+    return xpu_signal_pad_size;
+  }
+
+  bool has_multicast_support() override {
+    // ISHMEM does not have multicast support
+    return false;
+  }
+
+  void* get_multicast_ptr() override {
+    // ISHMEM does not have multicast support
+    return nullptr;
+  }
+
+  size_t get_offset() override {
+    return offset_;
+  }
+
+  void barrier(int channel, size_t timeout_ms) override {
+    // Use ISHMEM barrier
+    ishmem_barrier_all();
+  }
+
+  void put_signal(int dst_rank, int channel, size_t timeout_ms) override {
+    // TODO: Implement signal mechanism for ISHMEM
+    // ISHMEM uses different signaling approach than NVSHMEM
+  }
+
+  void wait_signal(int src_rank, int channel, size_t timeout_ms) override {
+    // TODO: Implement signal mechanism for ISHMEM
+  }
+
+  int get_rank() override {
+    return pai_->rank_;
+  }
+
+  int get_world_size() override {
+    return pai_->world_size_;
+  }
+
+  c10::Device get_device() override {
+    return c10::Device(c10::DeviceType::XPU, device_idx_);
+  }
+
+  const std::vector<int>& get_rank_to_global_rank() override {
+    return pai_->rank_to_global_rank_;
+  }
+
+  int* get_rank_to_global_rank_dev() override {
+    return pai_->rank_to_global_rank_dev_;
+  }
+
+  bool world_within_direct_access() override {
+    return pai_->world_within_xpu_p2p_;
+  }
+
+ private:
+  int device_idx_;
+  std::string group_name_;
+  c10::intrusive_ptr<ISHMEMPeerAllocInfo> pai_;
+  size_t offset_{0}; // in byte
+};
+
+// One-time, store-based bootstrap of the ISHMEM runtime: rank 0 creates a
+// unique id, all ranks exchange it via the c10d store, then every rank calls
+// ishmemx_init_attr with the same id.
+static void initialize_ishmem_with_store(
+    c10::intrusive_ptr<Store> store,
+    int rank,
+    int world_size,
+    int device_idx) {
+  static bool is_initialized = false;
+  if (is_initialized) {
+    return;
+  }
+
+  c10::OptionalDeviceGuard guard;
+  guard.reset_device(at::Device(at::DeviceType::XPU, device_idx));
+
+  ishmemx_uniqueid_t unique_id;
+  if (rank == 0) {
+    // Root rank generates the unique ID
+    int ret = ishmemx_get_uniqueid(&unique_id);
+    TORCH_CHECK(ret == 0, "ishmemx_get_uniqueid failed with error: ", ret);
+  }
+
+  auto unique_ids =
+      storeExchange.all_gather(store, rank, world_size, unique_id);
+
+  // Initialize ISHMEM with attributes using unique ID (rank 0's id is the
+  // first gathered element).
+  ishmemx_attr_t attr;
+  attr.initialize_runtime = false; // MPI/OpenSHMEM backend already initialized
+  attr.use_uid = true;
+  attr.nranks = world_size;
+  attr.uid = &unique_ids[0];
+
+  // ishmemx_init_attr returns void, not int
+  ishmemx_init_attr(&attr);
+  // Verify initialization succeeded by checking PE info
+  TORCH_CHECK(
+      ishmem_my_pe() == rank,
+      "ISHMEM initialization failed: rank mismatch, expected ",
+      rank,
+      " got ",
+      ishmem_my_pe());
+
+  is_initialized = true;
+
+  // Print version
+  int major, minor;
+  ishmem_info_get_version(&major, &minor);
+  LOG(INFO) << "ISHMEM initialized with unique ID, version: " << major << '.'
+            << minor << ", rank: " << rank << "/" << world_size;
+}
+
+class ISHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator {
+ public:
+  void* alloc(
+      size_t size,
+      int device_idx,
+      const std::optional<std::string>& group_name) override {
+    TORCH_CHECK(
+        group_name == std::nullopt,
+        "ISHMEMSymmetricMemoryAllocator::alloc "
+        "must not be called with a group_name");
+    c10::OptionalDeviceGuard guard;
+    guard.reset_device(at::Device(at::DeviceType::XPU, device_idx));
+
+    auto group_info = get_group_info("0");
+    auto store = group_info.store;
+    int rank = group_info.rank;
+    int world_size = group_info.world_size;
+
+    initialize_ishmem_with_store(store, rank, world_size, device_idx);
+    auto ptr = ishmem_malloc(size);
+    // If size is 0 (which is legal allocation request) we shouldn't error out
+    TORCH_CHECK(ptr != nullptr || size == 0, "ishmem_malloc failed");
+    // TODO: thread safety
+    allocations_.try_emplace(
+        ptr, std::make_unique<ISHMEMAllocation>(ptr, size, device_idx));
+    return ptr;
+  }
+
+  void free(void* ptr) override {
+    // TODO: thread safety
+    allocations_.erase(ptr);
+  }
+
+  size_t get_alloc_size(void* ptr) override {
+    auto it = allocations_.find(ptr);
+    if (it == allocations_.end()) {
+      TORCH_CHECK(
+          false, ptr, " is not allocated with ISHMEMSymmetricMemoryAllocator");
+    }
+    return it->second->buffer_size;
+  }
+
+  c10::intrusive_ptr<SymmetricMemory> rendezvous(
+      void* ptr,
+      const std::optional<std::string>& group_name) override {
+    TORCH_CHECK(group_name.has_value());
+    {
+      auto it = symm_mems_.find(std::make_tuple(ptr, *group_name));
+      if (it != symm_mems_.end()) {
+        return it->second;
+      }
+    }
+    // In case of MemPool, tensor.storage().data_ptr() may not match
+    // exactly an allocation's base address. Thus we perform the search by
+    // testing if the former is within an allocation's range.
+    auto alloc_it = std::find_if(
+        allocations_.begin(), allocations_.end(), [&](const auto& pair) {
+          auto& allocation = pair.second;
+          auto ptr_int = reinterpret_cast<uintptr_t>(ptr);
+          auto base_ptr = reinterpret_cast<uintptr_t>(allocation->ptr);
+          return ptr_int >= base_ptr &&
+              ptr_int < base_ptr + allocation->buffer_size;
+        });
+    TORCH_CHECK(
+        alloc_it != allocations_.end(),
+        "Pointer not within any SymmetricMemory allocation, "
+        "is the tensor allocated from SymmetricMemory?");
+
+    auto& allocation = alloc_it->second;
+
+    // Search again using allocation base ptr (which is the key we use for
+    // caching, see below)
+    auto it = symm_mems_.find(std::make_tuple(allocation->ptr, *group_name));
+    c10::intrusive_ptr<ISHMEMSymmetricMemory> symm_mem;
+    if (it != symm_mems_.end()) {
+      // Base allocation has been rendezvoused
+      symm_mem = it->second;
+    } else {
+      // Create a new rendezvous
+      symm_mem = c10::make_intrusive<ISHMEMSymmetricMemory>(
+          allocation.get(), *group_name);
+    }
+
+    // Cache rendezvous using allocation's base address as key
+    symm_mems_[std::make_tuple(allocation->ptr, *group_name)] = symm_mem;
+
+    // TODO: change the `ptr` below to `tensor.data_ptr()` when adding support
+    // for user slice/view operations. For MemPool support,
+    // `tensor.storage().data_ptr()` is fine (today's `ptr`).
+
+    // If the tensor's ptr happen to be the same as allocation ptr
+    if (ptr == allocation->ptr) {
+      return symm_mem;
+    } else {
+      // Return a copy of the SymmetricMemory with an offset. This is a
+      // "shallow" copy adjusting the offset field in the handle.
+      return c10::make_intrusive<ISHMEMSymmetricMemory>(
+          *symm_mem, (uintptr_t)ptr - (uintptr_t)allocation->ptr);
+    }
+  }
+
+  bool has_multicast_support(int device_idx) override {
+    // ISHMEM does not have multicast support
+    return false;
+  }
+
+  c10::DeviceType supported_device_type() override {
+    return c10::DeviceType::XPU;
+  }
+
+  std::string name() override {
+    return "ISHMEM";
+  }
+
+ private:
+  std::unordered_map<void*, std::unique_ptr<ISHMEMAllocation>> allocations_;
+  std::map<
+      std::tuple<void*, std::string>,
+      c10::intrusive_ptr<ISHMEMSymmetricMemory>>
+      symm_mems_;
+};
+
+struct RegisterISHMEMSymmetricMemoryAllocator {
+  RegisterISHMEMSymmetricMemoryAllocator() {
+    auto allocator = c10::make_intrusive<ISHMEMSymmetricMemoryAllocator>();
+    // Query backend used for XPU tensor
+    // Check TORCH_SYMMMEM environment variable
+    if (getSymmMemBackendXPU() == "ISHMEM") {
+      // Direct set (static registration)
+      register_allocator(c10::DeviceType::XPU, allocator);
+    } else {
+      // Register availability in case `set_backend` is called dynamically
+      register_availability("ISHMEM", allocator);
+    }
+  }
+};
+
+static RegisterISHMEMSymmetricMemoryAllocator register_allocator_;
+
+} // namespace symmetric_memory
+} // namespace c10d
diff --git a/src/xccl/XPUSymmetricMemoryUtils.hpp b/src/xccl/XPUSymmetricMemoryUtils.hpp
new file mode 100644
index 0000000000..ecc1286318
--- /dev/null
+++ b/src/xccl/XPUSymmetricMemoryUtils.hpp
@@ -0,0 +1,293 @@
+#pragma once
+
+#include <sys/socket.h>
+#include <sys/un.h>
+#include <unistd.h>
+#include <cstring>
+#include <sstream>
+
+#include <torch/csrc/distributed/c10d/Store.hpp>
+
+#include <c10/util/env.h>
+#include <c10/util/error.h>
+
+namespace c10d::symmetric_memory {
+
+// Query environment variable to get the backend used for XPU Symmetric Memory.
+// Defined `inline` because this header is included from multiple translation
+// units (ODR).
+inline std::string getSymmMemBackendXPU() {
+  // TORCH_SYMMMEM environment variable can be used to indicate the preferred
+  // backend.
+  static auto val = c10::utils::get_env("TORCH_SYMMMEM");
+  if (val.has_value()) {
+    TORCH_CHECK(
+        val.value() == "XPU" || val.value() == "ISHMEM",
+        "TORCH_SYMMMEM environment variable must be one of 'XPU', 'ISHMEM'.");
+    return val.value();
+  }
+  return "XPU";
+}
+
+// A Unix-domain-socket channel used to pass file descriptors between
+// processes on the same host (via SCM_RIGHTS control messages).
+class IpcChannel {
+ public:
+  IpcChannel();
+  ~IpcChannel();
+
+  void send_fd(int dst_pid, int fd);
+  int recv_fd();
+
+  std::vector<int> all_gather_fds(
+      int rank,
+      const std::vector<int>& pids,
+      int fd);
+
+  int broadcast_fds(
+      int rank,
+      int src_rank,
+      const std::vector<int>& pids,
+      int fd);
+
+ private:
+  static std::string get_socket_name(int pid);
+
+  std::string socket_name_;
+  int socket_;
+};
+
+// NOTE: member functions below are defined in this header, hence `inline`.
+inline IpcChannel::IpcChannel()
+    : socket_name_(get_socket_name(getpid())),
+      socket_(socket(AF_UNIX, SOCK_DGRAM, 0)) {
+  // On success, a file descriptor for the new socket is returned.
+  // On error, -1 is returned, and errno is set to indicate the error.
+  TORCH_CHECK(
+      socket_ != -1, "Failed to create socket: ", c10::utils::str_error(errno));
+
+  struct sockaddr_un addr = {.sun_family = AF_UNIX};
+  std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path);
+
+  TORCH_CHECK(
+      bind(socket_, (struct sockaddr*)&addr, SUN_LEN(&addr)) == 0,
+      "Failed to bind socket: ",
+      c10::utils::str_error(errno));
+}
+
+inline IpcChannel::~IpcChannel() {
+  close(socket_);
+  unlink(socket_name_.c_str());
+}
+
+inline void IpcChannel::send_fd(int dst_pid, int fd) {
+  // Because file descriptors are process-local kernel objects, and we can't
+  // pass them via normal socket payloads (like write() or send()). Unix domain
+  // sockets provide a mechanism to pass actual FDs via sendmsg()/recvmsg().
+  // Define destination socket address
+  struct sockaddr_un addr = {.sun_family = AF_UNIX};
+  auto socket_name = get_socket_name(dst_pid);
+  std::copy(socket_name.begin(), socket_name.end(), addr.sun_path);
+
+  // Prepare data to send
+  // Data being sent is "fd", the value of fd will be sent as auxiliary data
+  // (control message)
+  struct iovec io = {.iov_base = (void*)("fd"), .iov_len = 2};
+
+  // Prepare control message data buffer and zero it out
+  // NOLINTNEXTLINE(*array*)
+  char cbuf[CMSG_SPACE(sizeof(int))];
+  memset(cbuf, 0, sizeof(cbuf));
+
+  // Create message header
+  struct msghdr msg{
+      // destination socket address and size of it
+      // message content in msg_iov and number of such structs (1 in our case)
+      // auxiliary data with the value of fd and size of it
+      .msg_name = (void*)&addr,
+      .msg_namelen = sizeof(struct sockaddr_un),
+      .msg_iov = &io,
+      .msg_iovlen = 1,
+      .msg_control = cbuf,
+      .msg_controllen = sizeof(cbuf)};
+
+  // This points to the first control message header
+  // With SCM_RIGHTS we let the kernel know that we are passing file
+  // descriptors.
+  auto cmsg = CMSG_FIRSTHDR(&msg);
+  cmsg->cmsg_len = CMSG_LEN(sizeof(int));
+  // Specify socket level message
+  cmsg->cmsg_level = SOL_SOCKET;
+  // SCM_RIGHTS is the type used to pass file descriptors
+  cmsg->cmsg_type = SCM_RIGHTS;
+
+  if (fd != -1) {
+    std::copy(
+        reinterpret_cast<const char*>(&fd),
+        reinterpret_cast<const char*>(&fd) + sizeof(fd),
+        reinterpret_cast<char*>(CMSG_DATA(cmsg)));
+  } else {
+    msg.msg_controllen = 0;
+  }
+
+  // Finally send the message
+  TORCH_CHECK(
+      sendmsg(socket_, &msg, 0) > 0,
+      "Failed to send fd: ",
+      c10::utils::str_error(errno));
+}
+
+inline int IpcChannel::recv_fd() {
+  // Prepare buffer for regular message "fd"
+  // NOLINTNEXTLINE(*array*)
+  char buf[2];
+  memset(&buf, 0, sizeof(buf));
+  struct iovec io = {.iov_base = (void*)buf, .iov_len = sizeof(buf)};
+
+  // Prepare buffer for control message and zero it out
+  // NOLINTNEXTLINE(*array*)
+  char cbuf[CMSG_SPACE(sizeof(int))];
+  memset(cbuf, 0, sizeof(cbuf));
+
+  // Define socket address to receive on: family AF_UNIX means unix domain
+  // socket
+  struct sockaddr_un addr = {.sun_family = AF_UNIX};
+  std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path);
+
+  // Prepare message header
+  struct msghdr msg = {
+      .msg_name = (void*)&addr,
+      .msg_namelen = sizeof(struct sockaddr_un),
+      .msg_iov = &io,
+      .msg_iovlen = 1,
+      .msg_control = cbuf,
+      .msg_controllen = sizeof(cbuf)};
+
+  // Receive message on socket_
+  TORCH_CHECK(
+      recvmsg(socket_, &msg, 0) > 0,
+      "Failed to receive fd: ",
+      c10::utils::str_error(errno));
+
+  if (msg.msg_controllen == 0) {
+    return -1;
+  }
+
+  // Extract control message and validate its content
+  auto cmsg = CMSG_FIRSTHDR(&msg);
+  TORCH_CHECK(cmsg != nullptr);
+  TORCH_CHECK(cmsg->cmsg_len == CMSG_LEN(sizeof(int)));
+  TORCH_CHECK(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS);
+  return *reinterpret_cast<int*>(CMSG_DATA(cmsg));
+}
+
+inline std::vector<int> IpcChannel::all_gather_fds(
+    int rank,
+    const std::vector<int>& pids,
+    int fd) {
+  int world_size = static_cast<int>(pids.size());
+  std::vector<int> fds(pids.size());
+  fds[rank] = fd;
+
+  // Ring exchange: each step forwards the fd received in the previous step.
+  int dst_rank = (rank + 1) % world_size;
+  for (int step = 1; step < world_size; ++step) {
+    int src_rank = (rank + world_size - step) % world_size;
+    send_fd(pids[dst_rank], fd);
+    fd = recv_fd();
+    fds[src_rank] = fd;
+  }
+  return fds;
+}
+
+inline int IpcChannel::broadcast_fds(
+    int rank,
+    int src_rank,
+    const std::vector<int>& pids,
+    int fd) {
+  int world_size = static_cast<int>(pids.size());
+
+  if (rank == src_rank) {
+    for (int dst_rank = 0; dst_rank < world_size; ++dst_rank) {
+      if (dst_rank == rank) {
+        continue;
+      }
+      send_fd(pids[dst_rank], fd);
+    }
+    return fd;
+  }
+  return recv_fd();
+}
+
+inline std::string IpcChannel::get_socket_name(int pid) {
+  const char* tmp_dir = "/tmp";
+  for (const char* env_var : {"TMPDIR", "TMP", "TEMP", "TEMPDIR"}) {
+    if (const char* path = getenv(env_var)) {
+      tmp_dir = path;
+      break;
+    }
+  }
+  std::ostringstream oss;
+  oss << tmp_dir << "/symm_mem-" << pid;
+  return oss.str();
+}
+
+// Store-based all-gather / barrier helper for rendezvous bootstrapping.
+class StoreExchange {
+ public:
+  StoreExchange(const std::string& store_prefix)
+      : store_prefix_(store_prefix) {}
+
+  // Put template function in header file so that compiler can easily access it.
+  template <typename T>
+  std::vector<T> all_gather(
+      const c10::intrusive_ptr<Store>& store,
+      int rank,
+      int world_size,
+      T val) {
+    static_assert(std::is_trivially_copyable_v<T>);
+
+    std::vector<std::string> peer_keys;
+    peer_keys.reserve(world_size);
+    for (int r = 0; r < world_size; ++r) {
+      std::ostringstream oss;
+      oss << store_prefix_ << "/" << seq_id_ << "/" << r;
+      peer_keys.push_back(oss.str());
+    }
+    ++seq_id_;
+
+    {
+      std::vector<uint8_t> payload(
+          reinterpret_cast<uint8_t*>(&val),
+          reinterpret_cast<uint8_t*>(&val) + sizeof(T));
+      store->set(peer_keys[rank], payload);
+    }
+
+    std::vector<T> peer_vals;
+    peer_vals.reserve(world_size);
+    for (int r = 0; r < world_size; ++r) {
+      if (r == rank) {
+        peer_vals.push_back(val);
+        continue;
+      }
+      store->wait({peer_keys[r]});
+      auto payload = store->get(peer_keys[r]);
+      TORCH_CHECK(payload.size() == sizeof(T));
+      T peer_val{};
+      std::memcpy(&peer_val, payload.data(), sizeof(T));
+      peer_vals.push_back(peer_val);
+    }
+    return peer_vals;
+  }
+
+  void barrier(
+      const c10::intrusive_ptr<Store>& store,
+      int rank,
+      int world_size) {
+    // TODO: implement an efficient one?
+    all_gather(store, rank, world_size, 0);
+  }
+
+ private:
+  const std::string store_prefix_;
+  size_t seq_id_ = 0;
+};
+
+} // namespace c10d::symmetric_memory