-
Notifications
You must be signed in to change notification settings - Fork 15.6k
cuda : fix nkvo, offload and cuda graph node properties matching #19165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -310,8 +310,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const | |
| } | ||
| } | ||
|
|
||
| const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); | ||
|
|
||
| const int cc = ggml_cuda_info().devices[device].cc; | ||
|
|
||
| switch (K->ne[0]) { | ||
|
|
@@ -334,9 +332,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const | |
| if (!gqa_opt_applies) { | ||
| return BEST_FATTN_KERNEL_NONE; | ||
| } | ||
| if (!V_is_K_view) { | ||
| return BEST_FATTN_KERNEL_NONE; | ||
| } | ||
|
Comment on lines
-337
to
-339
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As of right now the MMA kernel with a head size of 576/512 is only compiled for the assumption of
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yup, got it. I'm currently looking into this and it seems like #18934 introduced a bug, hence it is taking me some time to figure out a good solution. Basically, partial offload (e.g. with ./bin/llama-perplexity -m ~/models/ggml-org_gpt-oss-20b-GGUF_gpt-oss-20b-mxfp4.gguf -f ../build-cuda/wikitext-2-raw/wiki.test.raw -fa on -ngl 10) is broken. I've traced it down to the properties matching logic incorrectly determining that the flash attention properties match. The net effect is, whenever the FA operator is being offloaded from the CPU to the GPU, the properties matching currently breaks.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think the issue that I described should be fixed now (see OP for more info). Regarding the check, I think it's OK to remove it for now. The check does not work correctly in the case where the FA op is offloaded from the CPU to the GPU, because in this case, we copy the K and V tensors separately and hence the V is no longer a view of K. Technically there is indeed a failure case if one has to ever use non-MLA FA with DK=576 and DV=512, but as of now we don't have such use cases. Let me know if this makes sense and you agree that it's OK to proceed with this. |
||
| break; | ||
| default: | ||
| return BEST_FATTN_KERNEL_NONE; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -70,17 +70,18 @@ | |
| #include <condition_variable> | ||
| #include <cstddef> | ||
| #include <cstdint> | ||
| #include <float.h> | ||
| #include <cfloat> | ||
| #include <initializer_list> | ||
| #include <limits> | ||
| #include <map> | ||
| #include <memory> | ||
| #include <mutex> | ||
| #include <stdarg.h> | ||
| #include <stdio.h> | ||
| #include <stdlib.h> | ||
| #include <cstdarg> | ||
| #include <cstdio> | ||
| #include <cstdlib> | ||
| #include <string> | ||
| #include <vector> | ||
| #include <unordered_set> | ||
|
|
||
| static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); | ||
|
|
||
|
|
@@ -2916,22 +2917,26 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { | |
| } | ||
|
|
||
| static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { | ||
| props->node_address = node->data; | ||
| memset(props, 0, sizeof(ggml_cuda_graph_node_properties)); | ||
| props->node_data = node->data; | ||
| props->node_op = node->op; | ||
| props->flags = node->flags; | ||
| for (int i = 0; i < GGML_MAX_DIMS; i++) { | ||
| props->ne[i] = node->ne[i]; | ||
| props->nb[i] = node->nb[i]; | ||
| } | ||
| for (int i = 0; i < GGML_MAX_SRC; i++) { | ||
| props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; | ||
| if (!node->src[i]) { | ||
| continue; | ||
| } | ||
|
|
||
| props->src_data[i] = node->src[i]->data; | ||
| } | ||
| memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); | ||
| } | ||
|
|
||
| static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { | ||
| if (node->data != props->node_address && | ||
| node->op != GGML_OP_VIEW) { | ||
| if (node->data != props->node_data && node->op != GGML_OP_VIEW) { | ||
| return false; | ||
| } | ||
|
|
||
|
|
@@ -2948,12 +2953,18 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_ | |
| } | ||
| } | ||
|
|
||
| for (int i = 0; i < GGML_MAX_SRC; i++) { | ||
| if (node->src[i] && | ||
| node->src[i]->data != props->src_address[i] && | ||
| node->op != GGML_OP_VIEW | ||
| ) { | ||
| return false; | ||
| if (node->op != GGML_OP_VIEW) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not that the code on master is doing this either, but shouldn't there be an
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
| for (int i = 0; i < GGML_MAX_SRC; i++) { | ||
| if (!node->src[i]) { | ||
| if (props->src_data[i] != nullptr) { | ||
| return false; | ||
| } | ||
| continue; | ||
| } | ||
|
|
||
| if (node->src[i]->data != props->src_data[i]) { | ||
| return false; | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -2974,7 +2985,6 @@ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) { | |
| } | ||
|
|
||
| static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { | ||
|
|
||
| bool res = false; | ||
|
|
||
| const void * graph_key = ggml_cuda_graph_get_key(cgraph); | ||
|
|
@@ -2985,33 +2995,52 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx | |
| } | ||
|
|
||
| // Check if the graph size has changed | ||
| if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { | ||
| if (graph->props.size() != (size_t)cgraph->n_nodes) { | ||
| res = true; | ||
| graph->props.resize(cgraph->n_nodes + cgraph->n_leafs); | ||
| graph->props.resize(cgraph->n_nodes); | ||
| } | ||
|
|
||
| // Loop over nodes in GGML graph to determine if CUDA graph update is required | ||
| // and store properties to allow this comparison for the next token | ||
| std::unordered_set<ggml_tensor *> seen_node; | ||
| std::vector<ggml_tensor *> srcs_extra; | ||
| for (int i = 0; i < cgraph->n_nodes; i++) { | ||
| bool props_match = true; | ||
|
|
||
| seen_node.insert(cgraph->nodes[i]); | ||
|
|
||
| if (!res) { | ||
| props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]); | ||
| } | ||
| if (!props_match) { | ||
| res = true; | ||
| } | ||
| ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]); | ||
|
|
||
| for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { | ||
| ggml_tensor * src = cgraph->nodes[i]->src[src_idx]; | ||
| if (src && seen_node.find(src) == seen_node.end()) { | ||
| srcs_extra.push_back(src); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if (graph->extra.size() != (size_t) srcs_extra.size()) { | ||
| res = true; | ||
| graph->extra.resize(srcs_extra.size()); | ||
| } | ||
|
|
||
| for (int i = 0; i < cgraph->n_leafs; i++) { | ||
| for (size_t i = 0; i < srcs_extra.size(); ++i) { | ||
| bool props_match = true; | ||
|
|
||
| if (!res) { | ||
| props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]); | ||
| props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]); | ||
| } | ||
|
|
||
| if (!props_match) { | ||
| res = true; | ||
| } | ||
| ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]); | ||
| ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]); | ||
| } | ||
|
|
||
| return res; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be useful to add a brief comment explaining what this is used for.