diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 62e618850bf..19d11a4939f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1041,6 +1041,8 @@ struct ggml_graph_node_properties { int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; void * src_address[GGML_MAX_SRC]; + int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; + size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; }; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 84eccea3f7b..7b7ab107ef7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2911,6 +2911,10 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p } for (int i = 0; i < GGML_MAX_SRC; i++) { graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; + for (int j = 0; j < GGML_MAX_DIMS; j++) { + graph_node_properties->src_ne[i][j] = node->src[i] ? node->src[i]->ne[j] : 0; + graph_node_properties->src_nb[i][j] = node->src[i] ? node->src[i]->nb[j] : 0; + } } memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS); } @@ -2935,12 +2939,22 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } for (int i = 0; i < GGML_MAX_SRC; i++) { - if (node->src[i] && - node->src[i]->data != graph_node_properties->src_address[i] && + if (!node->src[i]) { + continue; + } + if (node->src[i]->data != graph_node_properties->src_address[i] && node->op != GGML_OP_VIEW ) { return false; } + for (int j = 0; j < GGML_MAX_DIMS; j++) { + if (node->src[i]->ne[j] != graph_node_properties->src_ne[i][j]) { + return false; + } + if (node->src[i]->nb[j] != graph_node_properties->src_nb[i][j]) { + return false; + } + } } if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) && @@ -2952,7 +2966,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { - bool cuda_graph_update_required = false; if (cuda_ctx->cuda_graph->instance == nullptr) {