Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamically load libibverbs #337

Merged
merged 11 commits on
Aug 14, 2024
3 changes: 2 additions & 1 deletion src/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
#if defined(ENABLE_NPKIT)
#include <mscclpp/npkit/npkit.hpp>
#endif
#include <infiniband/verbs.h>

#include <mscclpp/utils.hpp>
#include <sstream>
#include <thread>

#include "debug.h"
#include "endpoint.hpp"
#include "infiniband/verbs.h"

namespace mscclpp {

Expand Down
58 changes: 29 additions & 29 deletions src/ib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

#include "ib.hpp"

#include <infiniband/verbs.h>
#include <malloc.h>
#include <unistd.h>

#include <cstring>
#include <fstream>
#include <ibverbs_wrapper.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/fifo.hpp>
#include <sstream>
Expand Down Expand Up @@ -43,9 +43,9 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
}
uintptr_t addr = reinterpret_cast<uintptr_t>(buff) & -pageSize;
std::size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
this->mr = ibv_reg_mr(pd, reinterpret_cast<void*>(addr), pages * pageSize,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ |
IBV_ACCESS_RELAXED_ORDERING | IBV_ACCESS_REMOTE_ATOMIC);
this->mr = IBVerbs::ibv_reg_mr2(pd, reinterpret_cast<void*>(addr), pages * pageSize,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ |
IBV_ACCESS_RELAXED_ORDERING | IBV_ACCESS_REMOTE_ATOMIC);
if (this->mr == nullptr) {
std::stringstream err;
err << "ibv_reg_mr failed (errno " << errno << ")";
Expand All @@ -54,7 +54,7 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
this->size = pages * pageSize;
}

IbMr::~IbMr() { ibv_dereg_mr(this->mr); }
IbMr::~IbMr() { IBVerbs::ibv_dereg_mr(this->mr); }

IbMrInfo IbMr::getInfo() const {
IbMrInfo info;
Expand All @@ -70,7 +70,7 @@ uint32_t IbMr::getLkey() const { return this->mr->lkey; }
IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr,
int maxWrPerSend)
: numSignaledPostedItems(0), numSignaledStagedItems(0), maxCqPollNum(maxCqPollNum), maxWrPerSend(maxWrPerSend) {
this->cq = ibv_create_cq(ctx, maxCqSize, nullptr, nullptr, 0);
this->cq = IBVerbs::ibv_create_cq(ctx, maxCqSize, nullptr, nullptr, 0);
if (this->cq == nullptr) {
std::stringstream err;
err << "ibv_create_cq failed (errno " << errno << ")";
Expand All @@ -89,15 +89,15 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
qpInitAttr.cap.max_recv_sge = 1;
qpInitAttr.cap.max_inline_data = 0;

struct ibv_qp* _qp = ibv_create_qp(pd, &qpInitAttr);
struct ibv_qp* _qp = IBVerbs::ibv_create_qp(pd, &qpInitAttr);
if (_qp == nullptr) {
std::stringstream err;
err << "ibv_create_qp failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
}

struct ibv_port_attr portAttr;
if (ibv_query_port(ctx, port, &portAttr) != 0) {
if (IBVerbs::ibv_query_port_w(ctx, port, &portAttr) != 0) {
std::stringstream err;
err << "ibv_query_port failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
Expand All @@ -111,7 +111,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN

if (portAttr.link_layer != IBV_LINK_LAYER_INFINIBAND || this->info.is_grh) {
union ibv_gid gid;
if (ibv_query_gid(ctx, port, 0, &gid) != 0) {
if (IBVerbs::ibv_query_gid(ctx, port, 0, &gid) != 0) {
std::stringstream err;
err << "ibv_query_gid failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
Expand All @@ -126,7 +126,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
qpAttr.pkey_index = 0;
qpAttr.port_num = port;
qpAttr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC;
if (ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
if (IBVerbs::ibv_modify_qp(_qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
std::stringstream err;
err << "ibv_modify_qp failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
Expand All @@ -139,8 +139,8 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
}

IbQp::~IbQp() {
ibv_destroy_qp(this->qp);
ibv_destroy_cq(this->cq);
IBVerbs::ibv_destroy_qp(this->qp);
IBVerbs::ibv_destroy_cq(this->cq);
}

void IbQp::rtr(const IbQpInfo& info) {
Expand All @@ -167,9 +167,9 @@ void IbQp::rtr(const IbQpInfo& info) {
qp_attr.ah_attr.sl = 0;
qp_attr.ah_attr.src_path_bits = 0;
qp_attr.ah_attr.port_num = info.port;
int ret = ibv_modify_qp(this->qp, &qp_attr,
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER);
int ret = IBVerbs::ibv_modify_qp(this->qp, &qp_attr,
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER);
if (ret != 0) {
std::stringstream err;
err << "ibv_modify_qp failed (errno " << errno << ")";
Expand All @@ -186,7 +186,7 @@ void IbQp::rts() {
qp_attr.rnr_retry = 7;
qp_attr.sq_psn = 0;
qp_attr.max_rd_atomic = 1;
int ret = ibv_modify_qp(
int ret = IBVerbs::ibv_modify_qp(
this->qp, &qp_attr,
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC);
if (ret != 0) {
Expand Down Expand Up @@ -265,7 +265,7 @@ void IbQp::postSend() {
return;
}
struct ibv_send_wr* bad_wr;
int ret = ibv_post_send(this->qp, this->wrs.get(), &bad_wr);
int ret = IBVerbs::ibv_post_send(this->qp, this->wrs.get(), &bad_wr);
if (ret != 0) {
std::stringstream err;
err << "ibv_post_send failed (errno " << errno << ")";
Expand All @@ -281,7 +281,7 @@ void IbQp::postSend() {
}

int IbQp::pollCq() {
int wcNum = ibv_poll_cq(this->cq, this->maxCqPollNum, this->wcs.get());
int wcNum = IBVerbs::ibv_poll_cq(this->cq, this->maxCqPollNum, this->wcs.get());
if (wcNum > 0) {
this->numSignaledPostedItems -= wcNum;
}
Expand All @@ -301,20 +301,20 @@ IbCtx::IbCtx(const std::string& devName) : devName(devName) {
}
#endif // !defined(__HIP_PLATFORM_AMD__)
int num;
struct ibv_device** devices = ibv_get_device_list(&num);
struct ibv_device** devices = IBVerbs::ibv_get_device_list(&num);
for (int i = 0; i < num; ++i) {
if (std::string(devices[i]->name) == devName) {
this->ctx = ibv_open_device(devices[i]);
this->ctx = IBVerbs::ibv_open_device(devices[i]);
break;
}
}
ibv_free_device_list(devices);
IBVerbs::ibv_free_device_list(devices);
if (this->ctx == nullptr) {
std::stringstream err;
err << "ibv_open_device failed (errno " << errno << ", device name << " << devName << ")";
throw mscclpp::IbError(err.str(), errno);
}
this->pd = ibv_alloc_pd(this->ctx);
this->pd = IBVerbs::ibv_alloc_pd(this->ctx);
if (this->pd == nullptr) {
std::stringstream err;
err << "ibv_alloc_pd failed (errno " << errno << ")";
Expand All @@ -326,16 +326,16 @@ IbCtx::~IbCtx() {
this->mrs.clear();
this->qps.clear();
if (this->pd != nullptr) {
ibv_dealloc_pd(this->pd);
IBVerbs::ibv_dealloc_pd(this->pd);
}
if (this->ctx != nullptr) {
ibv_close_device(this->ctx);
IBVerbs::ibv_close_device(this->ctx);
}
}

bool IbCtx::isPortUsable(int port) const {
struct ibv_port_attr portAttr;
if (ibv_query_port(this->ctx, port, &portAttr) != 0) {
if (IBVerbs::ibv_query_port_w(this->ctx, port, &portAttr) != 0) {
std::stringstream err;
err << "ibv_query_port failed (errno " << errno << ", port << " << port << ")";
throw mscclpp::IbError(err.str(), errno);
Expand All @@ -346,7 +346,7 @@ bool IbCtx::isPortUsable(int port) const {

int IbCtx::getAnyActivePort() const {
struct ibv_device_attr devAttr;
if (ibv_query_device(this->ctx, &devAttr) != 0) {
if (IBVerbs::ibv_query_device(this->ctx, &devAttr) != 0) {
std::stringstream err;
err << "ibv_query_device failed (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
Expand Down Expand Up @@ -382,7 +382,7 @@ const std::string& IbCtx::getDevName() const { return this->devName; }

MSCCLPP_API_CPP int getIBDeviceCount() {
int num;
ibv_get_device_list(&num);
IBVerbs::ibv_get_device_list(&num);
return num;
}

Expand Down Expand Up @@ -441,7 +441,7 @@ MSCCLPP_API_CPP std::string getIBDeviceName(Transport ibTransport) {
}

int num;
struct ibv_device** devices = ibv_get_device_list(&num);
struct ibv_device** devices = IBVerbs::ibv_get_device_list(&num);
if (ibTransportIndex >= num) {
std::stringstream ss;
ss << "IB transport out of range: " << ibTransportIndex << " >= " << num;
Expand All @@ -452,7 +452,7 @@ MSCCLPP_API_CPP std::string getIBDeviceName(Transport ibTransport) {

MSCCLPP_API_CPP Transport getIBTransportByDeviceName(const std::string& ibDeviceName) {
int num;
struct ibv_device** devices = ibv_get_device_list(&num);
struct ibv_device** devices = IBVerbs::ibv_get_device_list(&num);
for (int i = 0; i < num; ++i) {
if (ibDeviceName == devices[i]->name) {
switch (i) { // TODO: get rid of this ugly switch
Expand Down
Loading
Loading