Skip to content
Open
16 changes: 9 additions & 7 deletions src/core/agent_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ using nixl_comm_req_t = std::tuple<nixl_comm_t, std::string, int, nixl_blob_t>;

// Identifies a remote peer for socket lookup as a (string, int) pair; the
// listener code builds it from the request's IP string and port number.
using nixl_socket_peer_t = std::pair<std::string, int>;

// Ordered map from a peer key (nixl_socket_peer_t) to an int handle. The
// listener stores the connected client socket here and later passes the
// mapped value to send/receive helpers and to close().
using nixl_socket_map_t = std::map<nixl_socket_peer_t, int>;

class nixlAgentData {
private:
std::string name;
Expand Down Expand Up @@ -91,13 +93,13 @@ class nixlAgentData {
std::hash<std::string>, strEqual> remoteSections;

// State/methods for listener thread
nixlMDStreamListener *listener;
std::map<nixl_socket_peer_t, int> remoteSockets;
std::thread commThread;
std::vector<nixl_comm_req_t> commQueue;
std::mutex commLock;
bool commThreadStop;
bool useEtcd;
nixlMDStreamListener *listener;
nixl_socket_map_t remoteSockets;
std::thread commThread;
std::vector<nixl_comm_req_t> commQueue;
std::mutex commLock;
bool commThreadStop;
bool useEtcd;
std::unique_ptr<nixlTelemetry> telemetry_;
void commWorker(nixlAgent* myAgent);
void enqueueCommWork(nixl_comm_req_t request);
Expand Down
70 changes: 57 additions & 13 deletions src/core/nixl_listener.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,8 +494,7 @@ void nixlAgentData::commWorker(nixlAgent* myAgent){
nixl_socket_peer_t req_sock = std::make_pair(req_ip, req_port);

// use remote IP for socket lookup
const auto client = remoteSockets.find(req_sock);
int client_fd;
auto client = remoteSockets.find(req_sock);

// not connected
if (req_command < SOCK_MAX) {
Expand All @@ -507,23 +506,40 @@ void nixlAgentData::commWorker(nixlAgent* myAgent){
continue;
}
remoteSockets[req_sock] = new_client;
client_fd = new_client;
} else {
client_fd = client->second;
client = remoteSockets.find(req_sock);
}
}

bool needs_disconnect = false;
switch(req_command) {
case SOCK_SEND: {
sendCommMessage(client_fd, "NIXLCOMM:LOAD" + my_MD);
try {
sendCommMessage(client->second, "NIXLCOMM:LOAD" + my_MD);
}
catch (const std::runtime_error &e) {
NIXL_ERROR << "Failed to send message to peer, disconnecting: " << e.what();
needs_disconnect = true;
}
break;
}
case SOCK_FETCH: {
sendCommMessage(client_fd, "NIXLCOMM:SEND");
try {
sendCommMessage(client->second, "NIXLCOMM:SEND");
}
catch (const std::runtime_error &e) {
NIXL_ERROR << "Failed to send message to peer, disconnecting: " << e.what();
needs_disconnect = true;
}
break;
}
case SOCK_INVAL: {
sendCommMessage(client_fd, "NIXLCOMM:INVL" + name);
try {
sendCommMessage(client->second, "NIXLCOMM:INVL" + name);
}
catch (const std::runtime_error &e) {
NIXL_ERROR << "Failed to send message to peer, disconnecting: " << e.what();
needs_disconnect = true;
}
break;
}
#if HAVE_ETCD
Expand Down Expand Up @@ -596,6 +612,10 @@ void nixlAgentData::commWorker(nixlAgent* myAgent){
break;
}
}
if (needs_disconnect) {
close(client->second);
client = remoteSockets.erase(client);
}
}

// third, do remote commands
Expand All @@ -604,9 +624,21 @@ void nixlAgentData::commWorker(nixlAgent* myAgent){
std::string commands;
std::vector<std::string> command_list;
nixl_status_t ret;

if (!recvCommMessage(socket_iter->second, commands)) {
socket_iter++;
bool disconnected = false;

try {
const bool received = recvCommMessage(socket_iter->second, commands);
if (!received) {
// No message received, but without error condition.
// Skip to the next peer
socket_iter++;
continue;
}
}
catch (const std::runtime_error &e) {
NIXL_ERROR << "Failed to receive message from peer, disconnecting: " << e.what();
close(socket_iter->second);
socket_iter = remoteSockets.erase(socket_iter);
continue;
}

Expand Down Expand Up @@ -634,7 +666,14 @@ void nixlAgentData::commWorker(nixlAgent* myAgent){
nixl_blob_t my_MD;
myAgent->getLocalMD(my_MD);

sendCommMessage(socket_iter->second, std::string("NIXLCOMM:LOAD" + my_MD));
try {
sendCommMessage(socket_iter->second, std::string("NIXLCOMM:LOAD" + my_MD));
}
catch (const std::runtime_error &e) {
NIXL_ERROR << "Failed to send message to peer, disconnecting: " << e.what();
disconnected = true;
break;
}
} else if(header == "INVL") {
std::string remote_agent = command.substr(4);
myAgent->invalidateRemoteMD(remote_agent);
Expand All @@ -645,7 +684,12 @@ void nixlAgentData::commWorker(nixlAgent* myAgent){
}
}

socket_iter++;
if (disconnected) {
close(socket_iter->second);
socket_iter = remoteSockets.erase(socket_iter);
} else {
socket_iter++;
}
}

#if HAVE_ETCD
Expand Down