diff --git a/src/connection.cc b/src/connection.cc index fc3724c08..c2e5f5ba4 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -6,13 +6,14 @@ #if defined(ENABLE_NPKIT) #include #endif +#include + #include #include #include #include "debug.h" #include "endpoint.hpp" -#include "infiniband/verbs.h" namespace mscclpp { diff --git a/src/ib.cc b/src/ib.cc index 9955c5269..7d7a1b5eb 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -3,12 +3,12 @@ #include "ib.hpp" -#include #include #include #include #include +#include #include #include #include @@ -43,9 +43,9 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) { } uintptr_t addr = reinterpret_cast(buff) & -pageSize; std::size_t pages = (size + (reinterpret_cast(buff) - addr) + pageSize - 1) / pageSize; - this->mr = ibv_reg_mr(pd, reinterpret_cast(addr), pages * pageSize, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | - IBV_ACCESS_RELAXED_ORDERING | IBV_ACCESS_REMOTE_ATOMIC); + this->mr = IBVerbs::ibv_reg_mr2(pd, reinterpret_cast(addr), pages * pageSize, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | + IBV_ACCESS_RELAXED_ORDERING | IBV_ACCESS_REMOTE_ATOMIC); if (this->mr == nullptr) { std::stringstream err; err << "ibv_reg_mr failed (errno " << errno << ")"; @@ -54,7 +54,7 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) { this->size = pages * pageSize; } -IbMr::~IbMr() { ibv_dereg_mr(this->mr); } +IbMr::~IbMr() { IBVerbs::ibv_dereg_mr(this->mr); } IbMrInfo IbMr::getInfo() const { IbMrInfo info; @@ -70,7 +70,7 @@ uint32_t IbMr::getLkey() const { return this->mr->lkey; } IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend) : numSignaledPostedItems(0), numSignaledStagedItems(0), maxCqPollNum(maxCqPollNum), maxWrPerSend(maxWrPerSend) { - this->cq = ibv_create_cq(ctx, maxCqSize, nullptr, nullptr, 0); + this->cq = IBVerbs::ibv_create_cq(ctx, 
maxCqSize, nullptr, nullptr, 0); if (this->cq == nullptr) { std::stringstream err; err << "ibv_create_cq failed (errno " << errno << ")"; @@ -89,7 +89,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN qpInitAttr.cap.max_recv_sge = 1; qpInitAttr.cap.max_inline_data = 0; - struct ibv_qp* _qp = ibv_create_qp(pd, &qpInitAttr); + struct ibv_qp* _qp = IBVerbs::ibv_create_qp(pd, &qpInitAttr); if (_qp == nullptr) { std::stringstream err; err << "ibv_create_qp failed (errno " << errno << ")"; @@ -97,7 +97,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN } struct ibv_port_attr portAttr; - if (ibv_query_port(ctx, port, &portAttr) != 0) { + if (IBVerbs::ibv_query_port_w(ctx, port, &portAttr) != 0) { std::stringstream err; err << "ibv_query_port failed (errno " << errno << ")"; throw mscclpp::IbError(err.str(), errno); @@ -111,7 +111,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND || this->info.is_grh) { union ibv_gid gid; - if (ibv_query_gid(ctx, port, 0, &gid) != 0) { + if (IBVerbs::ibv_query_gid(ctx, port, 0, &gid) != 0) { std::stringstream err; err << "ibv_query_gid failed (errno " << errno << ")"; throw mscclpp::IbError(err.str(), errno); @@ -126,7 +126,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN qpAttr.pkey_index = 0; qpAttr.port_num = port; qpAttr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC; - if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + if (IBVerbs::ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { std::stringstream err; err << "ibv_modify_qp failed (errno " << errno << ")"; throw mscclpp::IbError(err.str(), errno); @@ -139,8 +139,8 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int 
maxCqSize, int maxCqPollN } IbQp::~IbQp() { - ibv_destroy_qp(this->qp); - ibv_destroy_cq(this->cq); + IBVerbs::ibv_destroy_qp(this->qp); + IBVerbs::ibv_destroy_cq(this->cq); } void IbQp::rtr(const IbQpInfo& info) { @@ -167,9 +167,9 @@ void IbQp::rtr(const IbQpInfo& info) { qp_attr.ah_attr.sl = 0; qp_attr.ah_attr.src_path_bits = 0; qp_attr.ah_attr.port_num = info.port; - int ret = ibv_modify_qp(this->qp, &qp_attr, - IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | - IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); + int ret = IBVerbs::ibv_modify_qp(this->qp, &qp_attr, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | + IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); if (ret != 0) { std::stringstream err; err << "ibv_modify_qp failed (errno " << errno << ")"; @@ -186,7 +186,7 @@ void IbQp::rts() { qp_attr.rnr_retry = 7; qp_attr.sq_psn = 0; qp_attr.max_rd_atomic = 1; - int ret = ibv_modify_qp( + int ret = IBVerbs::ibv_modify_qp( this->qp, &qp_attr, IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC); if (ret != 0) { @@ -265,7 +265,7 @@ void IbQp::postSend() { return; } struct ibv_send_wr* bad_wr; - int ret = ibv_post_send(this->qp, this->wrs.get(), &bad_wr); + int ret = IBVerbs::ibv_post_send(this->qp, this->wrs.get(), &bad_wr); if (ret != 0) { std::stringstream err; err << "ibv_post_send failed (errno " << errno << ")"; @@ -281,7 +281,7 @@ void IbQp::postSend() { } int IbQp::pollCq() { - int wcNum = ibv_poll_cq(this->cq, this->maxCqPollNum, this->wcs.get()); + int wcNum = IBVerbs::ibv_poll_cq(this->cq, this->maxCqPollNum, this->wcs.get()); if (wcNum > 0) { this->numSignaledPostedItems -= wcNum; } @@ -301,20 +301,20 @@ IbCtx::IbCtx(const std::string& devName) : devName(devName) { } #endif // !defined(__HIP_PLATFORM_AMD__) int num; - struct ibv_device** devices = ibv_get_device_list(&num); + struct ibv_device** devices = 
IBVerbs::ibv_get_device_list(&num); for (int i = 0; i < num; ++i) { if (std::string(devices[i]->name) == devName) { - this->ctx = ibv_open_device(devices[i]); + this->ctx = IBVerbs::ibv_open_device(devices[i]); break; } } - ibv_free_device_list(devices); + IBVerbs::ibv_free_device_list(devices); if (this->ctx == nullptr) { std::stringstream err; err << "ibv_open_device failed (errno " << errno << ", device name << " << devName << ")"; throw mscclpp::IbError(err.str(), errno); } - this->pd = ibv_alloc_pd(this->ctx); + this->pd = IBVerbs::ibv_alloc_pd(this->ctx); if (this->pd == nullptr) { std::stringstream err; err << "ibv_alloc_pd failed (errno " << errno << ")"; @@ -326,16 +326,16 @@ IbCtx::~IbCtx() { this->mrs.clear(); this->qps.clear(); if (this->pd != nullptr) { - ibv_dealloc_pd(this->pd); + IBVerbs::ibv_dealloc_pd(this->pd); } if (this->ctx != nullptr) { - ibv_close_device(this->ctx); + IBVerbs::ibv_close_device(this->ctx); } } bool IbCtx::isPortUsable(int port) const { struct ibv_port_attr portAttr; - if (ibv_query_port(this->ctx, port, &portAttr) != 0) { + if (IBVerbs::ibv_query_port_w(this->ctx, port, &portAttr) != 0) { std::stringstream err; err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")"; throw mscclpp::IbError(err.str(), errno); @@ -346,7 +346,7 @@ bool IbCtx::isPortUsable(int port) const { int IbCtx::getAnyActivePort() const { struct ibv_device_attr devAttr; - if (ibv_query_device(this->ctx, &devAttr) != 0) { + if (IBVerbs::ibv_query_device(this->ctx, &devAttr) != 0) { std::stringstream err; err << "ibv_query_device failed (errno " << errno << ")"; throw mscclpp::IbError(err.str(), errno); @@ -382,7 +382,7 @@ const std::string& IbCtx::getDevName() const { return this->devName; } MSCCLPP_API_CPP int getIBDeviceCount() { int num; - ibv_get_device_list(&num); + IBVerbs::ibv_get_device_list(&num); return num; } @@ -441,7 +441,7 @@ MSCCLPP_API_CPP std::string getIBDeviceName(Transport ibTransport) { } int num; - struct 
ibv_device** devices = ibv_get_device_list(&num); + struct ibv_device** devices = IBVerbs::ibv_get_device_list(&num); if (ibTransportIndex >= num) { std::stringstream ss; ss << "IB transport out of range: " << ibTransportIndex << " >= " << num; @@ -452,7 +452,7 @@ MSCCLPP_API_CPP std::string getIBDeviceName(Transport ibTransport) { MSCCLPP_API_CPP Transport getIBTransportByDeviceName(const std::string& ibDeviceName) { int num; - struct ibv_device** devices = ibv_get_device_list(&num); + struct ibv_device** devices = IBVerbs::ibv_get_device_list(&num); for (int i = 0; i < num; ++i) { if (ibDeviceName == devices[i]->name) { switch (i) { // TODO: get rid of this ugly switch
diff --git a/src/include/ibverbs_wrapper.hpp b/src/include/ibverbs_wrapper.hpp
new file mode 100644
index 000000000..e862cbea3
--- /dev/null
+++ b/src/include/ibverbs_wrapper.hpp
@@ -0,0 +1,281 @@
+#ifndef MSCCLPP_IBVERBS_WRAPPER_HPP_
+#define MSCCLPP_IBVERBS_WRAPPER_HPP_
+
+#include <dlfcn.h>
+#include <infiniband/verbs.h>
+
+#include <cerrno>
+#include <mscclpp/errors.hpp>
+#include <string>
+
+namespace mscclpp {
+
+/// Static wrapper that binds to libibverbs at run time instead of link time.
+///
+/// The shared library is opened with dlopen() the first time any wrapper is called, and each wrapper forwards its
+/// arguments unchanged to the libibverbs symbol resolved with dlsym(). A failed load is reported by throwing
+/// mscclpp::IbError from whichever wrapper triggered it.
+///
+/// NOTE(review): the lazy load is guarded by a plain bool rather than a lock or std::call_once, so the first call
+/// must not race with calls from other threads.
+struct IBVerbs {
+ private:
+  // Returns the pending dlerror() text. dlerror() returns NULL when no error is pending, and building a std::string
+  // from a null pointer is undefined behavior, so substitute a placeholder in that case.
+  static std::string lastDlError() {
+    const char* msg = dlerror();
+    return msg != nullptr ? std::string(msg) : std::string("unknown error");
+  }
+
+  // Resolves `name` from the loaded library into `fn`; appends `name` to `missing` if the symbol is not found.
+  template <typename Fn>
+  static void loadSymbol(Fn& fn, const char* name, std::string& missing) {
+    fn = reinterpret_cast<Fn>(dlsym(handle, name));
+    if (fn == nullptr) {
+      if (!missing.empty()) missing += ", ";
+      missing += name;
+    }
+  }
+
+  // Clears every function pointer and closes the library so that the wrapper is back in its unloaded state.
+  static void unload() {
+    ibv_get_device_list_lib = nullptr;
+    ibv_free_device_list_lib = nullptr;
+    ibv_alloc_pd_lib = nullptr;
+    ibv_dealloc_pd_lib = nullptr;
+    ibv_open_device_lib = nullptr;
+    ibv_close_device_lib = nullptr;
+    ibv_query_device_lib = nullptr;
+    ibv_create_cq_lib = nullptr;
+    ibv_create_qp_lib = nullptr;
+    ibv_destroy_cq_lib = nullptr;
+    ibv_reg_mr_lib = nullptr;
+    ibv_dereg_mr_lib = nullptr;
+    ibv_query_gid_lib = nullptr;
+    ibv_modify_qp_lib = nullptr;
+    ibv_destroy_qp_lib = nullptr;
+    ibv_query_port_lib = nullptr;
+    ibv_reg_mr_iova2_lib = nullptr;
+    initialized = false;
+    if (handle != nullptr) {
+      dlclose(handle);
+      handle = nullptr;
+    }
+  }
+
+  // Opens libibverbs and resolves every wrapped symbol. Throws mscclpp::IbError if the library or any symbol is
+  // missing; nothing stays loaded in that case, so the next wrapper call retries and reports the failure again
+  // instead of silently returning an error value.
+  static void initialize() {
+    // Try the unversioned name first, then the versioned SONAME (the unversioned symlink is typically shipped only
+    // by development packages).
+    handle = dlopen("libibverbs.so", RTLD_NOW);
+    if (handle == nullptr) handle = dlopen("libibverbs.so.1", RTLD_NOW);
+    if (handle == nullptr) {
+      throw mscclpp::IbError("Failed to load libibverbs: " + lastDlError(), errno);
+    }
+
+    std::string missing;
+    loadSymbol(ibv_get_device_list_lib, "ibv_get_device_list", missing);
+    loadSymbol(ibv_free_device_list_lib, "ibv_free_device_list", missing);
+    loadSymbol(ibv_alloc_pd_lib, "ibv_alloc_pd", missing);
+    loadSymbol(ibv_dealloc_pd_lib, "ibv_dealloc_pd", missing);
+    loadSymbol(ibv_open_device_lib, "ibv_open_device", missing);
+    loadSymbol(ibv_close_device_lib, "ibv_close_device", missing);
+    loadSymbol(ibv_query_device_lib, "ibv_query_device", missing);
+    loadSymbol(ibv_create_cq_lib, "ibv_create_cq", missing);
+    loadSymbol(ibv_create_qp_lib, "ibv_create_qp", missing);
+    loadSymbol(ibv_destroy_cq_lib, "ibv_destroy_cq", missing);
+    loadSymbol(ibv_reg_mr_lib, "ibv_reg_mr", missing);
+    loadSymbol(ibv_dereg_mr_lib, "ibv_dereg_mr", missing);
+    loadSymbol(ibv_query_gid_lib, "ibv_query_gid", missing);
+    loadSymbol(ibv_modify_qp_lib, "ibv_modify_qp", missing);
+    loadSymbol(ibv_destroy_qp_lib, "ibv_destroy_qp", missing);
+    loadSymbol(ibv_query_port_lib, "ibv_query_port", missing);
+    loadSymbol(ibv_reg_mr_iova2_lib, "ibv_reg_mr_iova2", missing);
+
+    if (!missing.empty()) {
+      const int savedErrno = errno;  // dlclose() in unload() may overwrite errno
+      unload();
+      throw mscclpp::IbError("Failed to load one or more function in the libibverbs library: " + missing,
+                             savedErrno);
+    }
+    // Only mark the wrapper usable once the library and every symbol have been resolved.
+    initialized = true;
+  }
+
+  // Loads the library on first use.
+  static void ensureLoaded() {
+    if (!initialized) initialize();
+  }
+
+ public:
+  // Static method to get the device list
+  static struct ibv_device** ibv_get_device_list(int* num_devices) {
+    ensureLoaded();
+    return ibv_get_device_list_lib(num_devices);
+  }
+
+  // Static method to free the device list
+  static void ibv_free_device_list(struct ibv_device** list) {
+    ensureLoaded();
+    ibv_free_device_list_lib(list);
+  }
+
+  // Static method to allocate a protection domain
+  static struct ibv_pd* ibv_alloc_pd(struct ibv_context* context) {
+    ensureLoaded();
+    return ibv_alloc_pd_lib(context);
+  }
+
+  // Static method to deallocate a protection domain
+  static int ibv_dealloc_pd(struct ibv_pd* pd) {
+    ensureLoaded();
+    return ibv_dealloc_pd_lib(pd);
+  }
+
+  // Static method to open a device
+  static struct ibv_context* ibv_open_device(struct ibv_device* device) {
+    ensureLoaded();
+    return ibv_open_device_lib(device);
+  }
+
+  // Static method to close a device
+  static int ibv_close_device(struct ibv_context* context) {
+    ensureLoaded();
+    return ibv_close_device_lib(context);
+  }
+
+  // Static method to query a device
+  static int ibv_query_device(struct ibv_context* context, struct ibv_device_attr* device_attr) {
+    ensureLoaded();
+    return ibv_query_device_lib(context, device_attr);
+  }
+
+  // Static method to create a completion queue
+  static struct ibv_cq* ibv_create_cq(struct ibv_context* context, int cqe, void* cq_context,
+                                      struct ibv_comp_channel* channel, int comp_vector) {
+    ensureLoaded();
+    return ibv_create_cq_lib(context, cqe, cq_context, channel, comp_vector);
+  }
+
+  // Static method to create a queue pair
+  static struct ibv_qp* ibv_create_qp(struct ibv_pd* pd, struct ibv_qp_init_attr* qp_init_attr) {
+    ensureLoaded();
+    return ibv_create_qp_lib(pd, qp_init_attr);
+  }
+
+  // Static method to destroy a completion queue
+  static int ibv_destroy_cq(struct ibv_cq* cq) {
+    ensureLoaded();
+    return ibv_destroy_cq_lib(cq);
+  }
+
+  // Static method to register a memory region (forwards to the libibverbs symbol "ibv_reg_mr")
+  static struct ibv_mr* ibv_reg_mr2(struct ibv_pd* pd, void* addr, size_t length, int access) {
+    ensureLoaded();
+    return ibv_reg_mr_lib(pd, addr, length, access);
+  }
+
+  // Static method to deregister a memory region
+  static int ibv_dereg_mr(struct ibv_mr* mr) {
+    ensureLoaded();
+    return ibv_dereg_mr_lib(mr);
+  }
+
+  // Static method to query a GID
+  static int ibv_query_gid(struct ibv_context* context, uint8_t port_num, int index, union ibv_gid* gid) {
+    ensureLoaded();
+    return ibv_query_gid_lib(context, port_num, index, gid);
+  }
+
+  // Static method to modify a queue pair
+  static int ibv_modify_qp(struct ibv_qp* qp, struct ibv_qp_attr* attr, int attr_mask) {
+    ensureLoaded();
+    return ibv_modify_qp_lib(qp, attr, attr_mask);
+  }
+
+  // Static method to destroy a queue pair
+  static int ibv_destroy_qp(struct ibv_qp* qp) {
+    ensureLoaded();
+    return ibv_destroy_qp_lib(qp);
+  }
+
+  // Static method to post send work requests; dispatches through the ops table of the queue pair's context
+  static inline int ibv_post_send(struct ibv_qp* qp, struct ibv_send_wr* wr, struct ibv_send_wr** bad_wr) {
+    ensureLoaded();
+    return qp->context->ops.post_send(qp, wr, bad_wr);
+  }
+
+  // Static method to poll a completion queue; dispatches through the ops table of the completion queue's context
+  static inline int ibv_poll_cq(struct ibv_cq* cq, int num_entries, struct ibv_wc* wc) {
+    ensureLoaded();
+    return cq->context->ops.poll_cq(cq, num_entries, wc);
+  }
+
+  // Static method to query a port (forwards to the libibverbs symbol "ibv_query_port")
+  static int ibv_query_port_w(struct ibv_context* context, uint8_t port_num, struct ibv_port_attr* port_attr) {
+    ensureLoaded();
+    return ibv_query_port_lib(context, port_num, port_attr);
+  }
+
+  // Static method to register a memory region with an explicit IO virtual address
+  static struct ibv_mr* ibv_reg_mr_iova2_w(struct ibv_pd* pd, void* addr, size_t length, uint64_t iova,
+                                           unsigned int access) {
+    ensureLoaded();
+    return ibv_reg_mr_iova2_lib(pd, addr, length, iova, access);
+  }
+
+  // Static method to clean up: closes the library and forgets every resolved symbol. Objects obtained through the
+  // wrappers must not be used afterwards; a later wrapper call loads the library again.
+  static void cleanup() { unload(); }
+
+ private:
+  // Handle for the dynamic library
+  static inline void* handle = nullptr;
+
+  // Function pointer types
+  using ibv_get_device_list_t = struct ibv_device** (*)(int*);
+  using ibv_free_device_list_t = void (*)(struct ibv_device**);
+  using ibv_alloc_pd_t = struct ibv_pd* (*)(struct ibv_context*);
+  using ibv_dealloc_pd_t = int (*)(struct ibv_pd*);
+  using ibv_open_device_t = struct ibv_context* (*)(struct ibv_device*);
+  using ibv_close_device_t = int (*)(struct ibv_context*);
+  using ibv_query_device_t = int (*)(struct ibv_context*, struct ibv_device_attr*);
+  using ibv_create_cq_t = struct ibv_cq* (*)(struct ibv_context*, int, void*, struct ibv_comp_channel*, int);
+  using ibv_create_qp_t = struct ibv_qp* (*)(struct ibv_pd*, struct ibv_qp_init_attr*);
+  using ibv_destroy_cq_t = int (*)(struct ibv_cq*);
+  using ibv_destroy_qp_t = int (*)(struct ibv_qp*);
+  using ibv_reg_mr_t = struct ibv_mr* (*)(struct ibv_pd*, void*, size_t, int);
+  using ibv_dereg_mr_t = int (*)(struct ibv_mr*);
+  using ibv_query_gid_t = int (*)(struct ibv_context*, uint8_t, int, union ibv_gid*);
+  using ibv_modify_qp_t = int (*)(struct ibv_qp*, struct ibv_qp_attr*, int);
+  using ibv_query_port_t = int (*)(struct ibv_context*, uint8_t, struct ibv_port_attr*);
+  using ibv_reg_mr_iova2_t = struct ibv_mr* (*)(struct ibv_pd*, void*, size_t, uint64_t, unsigned int);
+
+  // Resolved libibverbs entry points (null until initialize() succeeds)
+  static inline ibv_get_device_list_t ibv_get_device_list_lib = nullptr;
+  static inline ibv_free_device_list_t ibv_free_device_list_lib = nullptr;
+  static inline ibv_alloc_pd_t ibv_alloc_pd_lib = nullptr;
+  static inline ibv_dealloc_pd_t ibv_dealloc_pd_lib = nullptr;
+  static inline ibv_open_device_t ibv_open_device_lib = nullptr;
+  static inline ibv_close_device_t ibv_close_device_lib = nullptr;
+  static inline ibv_query_device_t ibv_query_device_lib = nullptr;
+  static inline ibv_create_cq_t ibv_create_cq_lib = nullptr;
+  static inline ibv_create_qp_t ibv_create_qp_lib = nullptr;
+  static inline ibv_destroy_cq_t ibv_destroy_cq_lib = nullptr;
+  static inline ibv_reg_mr_t ibv_reg_mr_lib = nullptr;
+  static inline ibv_dereg_mr_t ibv_dereg_mr_lib = nullptr;
+  static inline ibv_query_gid_t ibv_query_gid_lib = nullptr;
+  static inline ibv_modify_qp_t ibv_modify_qp_lib = nullptr;
+  static inline ibv_destroy_qp_t ibv_destroy_qp_lib = nullptr;
+  static inline ibv_query_port_t ibv_query_port_lib = nullptr;
+  static inline ibv_reg_mr_iova2_t ibv_reg_mr_iova2_lib = nullptr;
+
+  // True once the library and every symbol above have been loaded
+  static inline bool initialized = false;
+};
+
+}  // namespace mscclpp
+
+#endif  // MSCCLPP_IBVERBS_WRAPPER_HPP_
\ No newline at end of file