Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1149,8 +1149,8 @@ struct ggml_cuda_graph {
size_t num_nodes = 0;
std::vector<cudaGraphNode_t> nodes;
bool disable_due_to_gpu_arch = false;
bool disable_due_to_too_many_updates = false;
int number_consecutive_updates = 0;
bool disable_due_to_too_many_rebuilds = false;
int number_consecutive_rebuilds = 0;
std::vector<ggml_cuda_graph_node_properties> props;

// these are extra tensors (inputs) that participate in the ggml graph but are not nodes
Expand All @@ -1161,19 +1161,19 @@ struct ggml_cuda_graph {

// Maintains a streak counter of consecutive calls in which both `use_graph` and
// `update_required` were true, and sets a disable flag once the streak reaches 4.
// Any call where either argument is false resets the streak to 0.
// The disable flag is never cleared here; once set it stays set.
// NOTE(review): this span appears to interleave the removed ("updates") and
// added ("rebuilds") lines of a diff hunk, so each statement shows up twice and
// the braces do not balance — confirm against the real file before relying on it.
void record_update(bool use_graph, bool update_required) {
if (use_graph && update_required) {
// Both conditions hold: extend the streak.
number_consecutive_updates++;
number_consecutive_rebuilds++;
} else {
// Streak broken: start counting again from zero.
number_consecutive_updates = 0;
number_consecutive_rebuilds = 0;
}
// Threshold of 4 consecutive occurrences trips the disable flag (read by is_enabled()).
if (number_consecutive_updates >= 4) {
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
disable_due_to_too_many_updates = true;
if (number_consecutive_rebuilds >= 4) {
GGML_LOG_INFO("%s: disabling CUDA graphs due to too many consecutive rebuilds\n", __func__);
disable_due_to_too_many_rebuilds = true;
}
}

// Returns true only when none of the disable conditions hold: the GPU-arch flag,
// the GGML_CUDA_DISABLE_GRAPHS environment variable being set (to any value,
// including empty), or the "too many consecutive ..." flag set by record_update().
bool is_enabled() const {
// Function-local static: the environment is queried once and the result reused
// on every later call, so changing the variable afterwards has no effect.
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
// NOTE(review): unreachable as written — this looks like the added line of a
// diff hunk sitting directly after the removed one; confirm which return the
// real file keeps.
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_rebuilds);
}
#endif
};
Expand Down
11 changes: 5 additions & 6 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3048,17 +3048,18 @@ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_c
#endif // CUDART_VERSION >= 12000

if (stat == cudaErrorGraphExecUpdateFailure) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
#endif

//#ifndef NDEBUG
// GGML_LOG_INFO("%s: CUDA graph update failed due to %d\n", __func__, static_cast<int>(result_info));
//#endif
graph->record_update(true, true);
// The pre-existing graph exec cannot be updated due to violated constraints
// so instead clear error and re-instantiate
(void)cudaGetLastError();
CUDA_CHECK(cudaGraphExecDestroy(graph->instance));
graph->instance = nullptr;
CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
} else {
graph->record_update(true, false);
GGML_ASSERT(stat == cudaSuccess);
}
}
Expand Down Expand Up @@ -3937,8 +3938,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
if (graph->is_enabled()) {
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);

graph->record_update(use_cuda_graph, cuda_graph_update_required);
}
#endif // USE_CUDA_GRAPH

Expand Down
Loading