rpc: check op supporting #11923

Open · wants to merge 1 commit into master
106 changes: 102 additions & 4 deletions ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -86,6 +86,7 @@ enum rpc_cmd {
RPC_CMD_GET_DEVICE_MEMORY,
RPC_CMD_INIT_TENSOR,
RPC_CMD_GET_ALLOC_SIZE,
RPC_CMD_SUPPORT_OP,
RPC_CMD_COUNT,
};

@@ -158,6 +159,11 @@ struct rpc_msg_get_device_memory_rsp {
uint64_t free_mem;
uint64_t total_mem;
};

struct rpc_msg_support_op_rsp {
uint8_t result;
};

#pragma pack(pop)

// RPC data structures
@@ -438,7 +444,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
rpc_tensor result;
result.id = reinterpret_cast<uint64_t>(tensor);
result.type = tensor->type;
if (tensor->buffer) {
if (tensor->buffer && tensor->buffer->context) {
ggml_backend_buffer_t buffer = tensor->buffer;
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
result.buffer = ctx->remote_ptr;
@@ -767,6 +773,31 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
get_device_memory(sock, free, total);
}

static bool ggml_backend_rpc_support_op(const char * endpoint, const ggml_tensor * tensor) {
std::vector<uint8_t> input;
{
std::vector<rpc_tensor> tensors;
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (tensor->src[i] == nullptr) {
break;
}
tensors.push_back(serialize_tensor(tensor->src[i]));
}
tensors.push_back(serialize_tensor(tensor));
// serialization format: | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
uint32_t n_tensors = tensors.size();
int input_size = sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
input.resize(input_size, 0);
memcpy(input.data(), &n_tensors, sizeof(n_tensors));
memcpy(input.data() + sizeof(n_tensors), tensors.data(), n_tensors * sizeof(rpc_tensor));
}
rpc_msg_support_op_rsp response;
auto sock = get_socket(endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_SUPPORT_OP, input.data(), input.size(), &response, sizeof(response));
GGML_ASSERT(status);
return response.result;
}
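
For reference, the request this helper builds for a node with two sources is laid out as below. This is only an illustration of the format documented in the comment above; rpc_tensor is the packed struct defined earlier in this file.

```cpp
// RPC_CMD_SUPPORT_OP request payload for a node with two sources:
// | n_tensors = 3 (uint32_t) | rpc_tensor(src0) | rpc_tensor(src1) | rpc_tensor(node) |
// The node itself is always serialized last; the server-side handler below
// relies on that ordering when it rebuilds the op and re-attaches its sources.
uint32_t n_tensors   = 3; // src0, src1, and the node
size_t   request_len = sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
```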

// RPC server-side implementation

class rpc_server {
@@ -786,6 +817,7 @@ class rpc_server {
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
bool init_tensor(const rpc_msg_init_tensor_req & request);
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
bool support_op(const std::vector<uint8_t> & input, rpc_msg_support_op_rsp & response);

private:
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -829,6 +861,42 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
return true;
}

bool rpc_server::support_op(const std::vector<uint8_t> & input, rpc_msg_support_op_rsp & response) {
// serialization format: | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
if (input.size() < sizeof(uint32_t)) {
GGML_LOG_ERROR("[%s] invalid input size\n", __func__);
return false;
}
uint32_t n_tensors;
memcpy(&n_tensors, input.data(), sizeof(n_tensors));
if (input.size() < sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor)) {
GGML_LOG_ERROR("[%s] invalid input size\n", __func__);
return false;
}
const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(uint32_t));
GGML_PRINT_DEBUG("[%s] n_tensors: %u\n", __func__, n_tensors);

size_t buf_size = ggml_tensor_overhead()*n_tensors;
struct ggml_init_params params {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
struct ggml_context * ctx = ggml_init(params);
ggml_tensor * tensor = deserialize_tensor(ctx, &tensors[n_tensors-1]);
for (uint32_t i = 0; i < n_tensors-1; i++) {
ggml_tensor * src = deserialize_tensor(ctx, &tensors[i]);
tensor->src[i] = src;
}
response.result = true;
if (backend->device->iface.supports_op) {
response.result = backend->device->iface.supports_op(backend->device, tensor);
}
ggml_free(ctx);

return true;
}
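
For context, the check performed in this handler mirrors what the public wrapper in ggml-backend does for an in-process backend, so the remote answer should match what the same backend would report locally. A one-line sketch, assuming `backend` and `tensor` are the same objects as in the handler above:

```cpp
bool local = ggml_backend_supports_op(backend, tensor); // delegates to the device's supports_op hook
```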

void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
@@ -1326,6 +1394,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
}
break;
}
case RPC_CMD_SUPPORT_OP: {
std::vector<uint8_t> input;
if (!recv_msg(sockfd, input)) {
return;
}
rpc_msg_support_op_rsp response;
if (!server.support_op(input, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break;
}
default: {
fprintf(stderr, "Unknown command: %d\n", cmd);
return;
@@ -1436,10 +1518,26 @@ static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_b
}

static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
static std::unordered_map<std::string, std::unordered_map<std::string, bool>> caches;
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;

auto &cache = caches[ctx->endpoint];
std::string key = op->name;
key += std::to_string(op->type);
for (int i = 0; i < GGML_MAX_DIMS; i++) {
key += std::to_string(op->ne[i]);
}
key += std::to_string(op->op);
Comment on lines +1525 to +1530 (Contributor Author): should this be enough to distinguish one tensor?
auto it = cache.find(key);
if (it != cache.end()) {
return it->second;
}
bool result = ggml_backend_rpc_support_op(ctx->endpoint.c_str(), op);
cache[key] = result;
return result;

GGML_UNUSED(dev);
GGML_UNUSED(op);
//TODO: call the remote backend and cache the results
return true;
}

static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
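
On the review question above about the cache key: the key is built only from the node's own name, type, shape, and op, so two nodes that differ only in their source tensors (for example, the same matrix multiplication with F32 versus quantized sources) map to the same entry. A hypothetical stricter key that also folds in each source's type and shape could look like the sketch below; this is only an illustration of the trade-off, not part of the change.

```cpp
// Hypothetical cache key that also encodes the sources (illustration only).
// 'op' is the ggml_tensor passed to ggml_backend_rpc_device_supports_op above.
std::string key = std::to_string(op->op) + ":" + std::to_string(op->type);
for (int i = 0; i < GGML_MAX_DIMS; i++) {
    key += ":" + std::to_string(op->ne[i]);
}
for (int i = 0; i < GGML_MAX_SRC && op->src[i] != nullptr; i++) {
    key += "|" + std::to_string(op->src[i]->type);
    for (int j = 0; j < GGML_MAX_DIMS; j++) {
        key += ":" + std::to_string(op->src[i]->ne[j]);
    }
}
```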
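
A minimal end-to-end sketch of how the new check is reached from user code, assuming an rpc-server is listening on 127.0.0.1:50052 (a hypothetical address) and using the public device helpers from ggml-backend.h and ggml-rpc.h. The call goes through the generic device interface, which dispatches to ggml_backend_rpc_device_supports_op from this file and, via RPC_CMD_SUPPORT_OP, to the remote backend.

```cpp
#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-rpc.h"

// Ask the remote backend whether it can execute a small mul_mat node.
static bool remote_supports_mul_mat(const char * endpoint) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,   // only tensor metadata is needed for the check
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
    struct ggml_tensor * b    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 32);
    struct ggml_tensor * node = ggml_mul_mat(ctx, a, b);   // node->src[0] = a, node->src[1] = b

    ggml_backend_dev_t dev = ggml_backend_rpc_add_device(endpoint);
    bool ok = dev != NULL && ggml_backend_dev_supports_op(dev, node);

    ggml_free(ctx);
    return ok;
}
```

Example call: `remote_supports_mul_mat("127.0.0.1:50052")` returns true if the remote backend reports the op as supported, or falls back to the default true when the remote device does not implement a supports_op hook.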