Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,8 @@ struct ggml_graph_node_properties {
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
void * src_address[GGML_MAX_SRC];
+ int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
+ size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};

Expand Down
19 changes: 16 additions & 3 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2911,6 +2911,10 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
+ for (int j = 0; j < GGML_MAX_DIMS; j++) {
+ graph_node_properties->src_ne[i][j] = node->src[i] ? node->src[i]->ne[j] : 0;
+ graph_node_properties->src_nb[i][j] = node->src[i] ? node->src[i]->nb[j] : 0;
+ }
}
memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
Expand All @@ -2935,12 +2939,22 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
}

for (int i = 0; i < GGML_MAX_SRC; i++) {
- if (node->src[i] &&
- node->src[i]->data != graph_node_properties->src_address[i] &&
+ if (!node->src[i]) {
+ continue;
+ }
+ if (node->src[i]->data != graph_node_properties->src_address[i] &&
node->op != GGML_OP_VIEW
) {
return false;
}
+ for (int j = 0; j < GGML_MAX_DIMS; j++) {
+ if (node->src[i]->ne[j] != graph_node_properties->src_ne[i][j]) {
+ return false;
+ }
+ if (node->src[i]->nb[j] != graph_node_properties->src_nb[i][j]) {
+ return false;
+ }
+ }
}

if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
Expand All @@ -2952,7 +2966,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
}

static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
-
bool cuda_graph_update_required = false;

if (cuda_ctx->cuda_graph->instance == nullptr) {
Expand Down
Loading