Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3696,6 +3696,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
}

static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ctx) {

#ifdef USE_CUDA_GRAPH
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);

Expand Down Expand Up @@ -3736,17 +3737,15 @@ static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ct
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;

ggml_cuda_set_device(cuda_ctx->device);

bool use_cuda_graph = false;
bool cuda_graph_update_required = false;

// graph_optimize calls set_cuda_graph_enabled; in case it was not called (i.e. graph_compute is called directly),
// we call it here instead.
#ifdef USE_CUDA_GRAPH
if (!cuda_ctx->cuda_graph) {
use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
} else {
use_cuda_graph = cuda_ctx->cuda_graph && cuda_ctx->cuda_graph->cuda_graphs_enabled;
}
use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);

if (use_cuda_graph) {
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
Expand All @@ -3762,6 +3761,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,

if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
cuda_ctx->cuda_graph->cuda_graphs_enabled = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
#endif
Expand Down
Loading