Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion ggml/src/ggml-openvino/ggml-openvino.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,11 @@ static ggml_guid_t ggml_backend_openvino_guid(void) {
return &guid;
}

std::shared_ptr<ov_runtime_context> get_ov_runtime_context_ptr() {
static std::shared_ptr<ov_runtime_context> r_ctx = std::make_shared<ov_runtime_context>();
return r_ctx;
}

// backend API
GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device) {
if (device < 0 || device >= ggml_backend_openvino_get_device_count()) {
Expand All @@ -650,7 +655,7 @@ GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device) {
return nullptr;
}

ctx->runtime_context = std::make_shared<ov_runtime_context>();
ctx->runtime_context = get_ov_runtime_context_ptr();
if (ctx->runtime_context == nullptr) {
GGML_LOG_ERROR("%s: failed to allocate runtime context\n", __func__);
delete ctx;
Expand Down
21 changes: 20 additions & 1 deletion ggml/src/ggml-openvino/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,19 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr<
auto states = infer_request->query_state();
for (auto state : states) {
auto state_tensor = state.get_state();
if (static_cast<uint32_t>(pos_data[0]) > r_ctx->stateful_kv_size) {
std::string state_name;
Comment thread
cavusmustafa marked this conversation as resolved.
Outdated
try {
state_name = r_ctx->kv_state_input_name_map.at(state.get_name());
} catch (...) {
GGML_LOG_ERROR("GGML OpenVINO backend stateful inference failed: no input found for the state\n");
return GGML_STATUS_FAILED;
}
auto kv_tensor = get_ov_input_tensor(ggml_decoder, state_name);
kv_tensor.set_shape({state_tensor.get_shape()[0], kv_tensor.get_shape()[2],
state_tensor.get_shape()[2], state_tensor.get_shape()[3]});
Comment thread
cavusmustafa marked this conversation as resolved.
Outdated
state_tensor = kv_tensor;
}
ov::Coordinate begin = {0, 0, 0, 0};
ov::Coordinate end = {state_tensor.get_shape()[0], static_cast<uint32_t>(pos_data[0]),
state_tensor.get_shape()[2], state_tensor.get_shape()[3]};
Expand Down Expand Up @@ -196,7 +209,13 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr<
r_ctx->ov_output_names_cache[key] = std::move(ov_output_names);

if (stateful) {
r_ctx->stateful_kv_size = 0;
const auto * inp_pos = get_inp_pos_tensor(cgraph);
auto pos_shape = ggml_decoder->get_shape(inp_pos);
r_ctx->stateful_kv_size = pos_shape[3];
const auto kv_param_res_names = ggml_decoder->get_kv_param_res_names();
for (const auto& pair : kv_param_res_names) {
r_ctx->kv_state_input_name_map[pair.first+pair.second] = pair.first;
}
}
}

Expand Down
5 changes: 4 additions & 1 deletion ggml/src/ggml-openvino/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,15 @@ struct graph_key_hash {
struct ov_runtime_context {
std::string device;
bool stateful;
size_t stateful_kv_size;
std::unordered_map<graph_key, std::shared_ptr<GgmlOvDecoder>, graph_key_hash> decoder_cache;
std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache;
std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache_prefill;
std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_input_names_cache;
std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_output_names_cache;
// TODO: Stateful inference is only supported for a single request at a time.
// Support for simultaneous stateful inference requests is to be added.
size_t stateful_kv_size;
std::map<std::string, std::string> kv_state_input_name_map;

ov_runtime_context() :
device("CPU"),
Expand Down
Loading