diff --git a/src/ggml-cuda/common.cuh b/src/ggml-cuda/common.cuh index 3aec1742ee..e2121e7ede 100644 --- a/src/ggml-cuda/common.cuh +++ b/src/ggml-cuda/common.cuh @@ -1194,7 +1194,12 @@ struct ggml_cuda_graph { bool is_enabled() const { static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); - return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env); + // Disable graphs when the per-op perf logger is on: graph capture + // would either hide individual-op timings inside cudaGraphLaunch + // or re-record over still-pending events on subsequent launches. + // See ggml-cuda.cu's ggml_cuda_perf_logger comment for context. + static const bool disable_cuda_graphs_due_to_perf_logger = (getenv("GGML_CUDA_PERF_LOGGER") != nullptr); + return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_cuda_graphs_due_to_perf_logger); } #endif }; @@ -1467,11 +1472,23 @@ struct ggml_cuda_mm_fusion_args_host { const ggml_tensor * x_bias = nullptr; const ggml_tensor * gate = nullptr; const ggml_tensor * gate_bias = nullptr; + // Residual tensor added to the matmul output AFTER bias and (if any) gate. + // When both x_bias and x_residual are set the kernel performs + // dst = mat * y + bias + residual + // in a single dispatch, mirroring ggml-vulkan's MUL_MAT_ADD_ADD shader and + // saving the launch overhead of a stand-alone GGML_OP_ADD per residual + // connection. Used by the 3-op MUL_MAT + ADD(bias) + ADD(residual) fusion + // detected in ggml_backend_cuda_graph_compute. Must have ne[0] == + // dst->ne[0] and the same shape as the bias-add output (no broadcasting). + // Set to nullptr for normal 2-op MUL_MAT + ADD(bias) fusion or unfused + // dispatch. 
+ const ggml_tensor * x_residual = nullptr; ggml_glu_op glu_op; }; struct ggml_cuda_mm_fusion_args_device { const void * x_bias = nullptr; const void * gate = nullptr; const void * gate_bias = nullptr; + const void * x_residual = nullptr; // see _host counterpart for semantics ggml_glu_op glu_op; }; diff --git a/src/ggml-cuda/conv-transpose-1d.cu b/src/ggml-cuda/conv-transpose-1d.cu index 8418ba6673..ca71ef748b 100644 --- a/src/ggml-cuda/conv-transpose-1d.cu +++ b/src/ggml-cuda/conv-transpose-1d.cu @@ -1,57 +1,110 @@ #include "conv-transpose-1d.cuh" -static __global__ void conv_transpose_1d_kernel( +// One CUDA warp (32 threads) cooperatively computes one output pixel +// dst[oc, ol] (== dst[ol + OL*oc] in linear index, since we keep ne2/ne3 == 1). +// +// Grid : (OL, OC, 1) +// Block: (32, 1, 1) — exactly one warp; sized below as CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE. +// +// Two perf-critical changes vs the original "1 thread per output pixel + scan +// the full IC*IL grid + skip via conditional" implementation: +// +// 1. Narrow the input position i to the small range that actually +// contributes: +// out[ol, oc] = sum over (ic, i, ki) of k[ki, oc, ic] * x[i, ic] +// subject to i*s0 + ki == ol, 0 <= ki < K, 0 <= i < IL +// ⇒ i ∈ [ ceil((ol - K + 1)/s0), floor(ol/s0) ] ∩ [0, IL-1] +// typically (KS=16, s0=8) this is 2 iterations of i instead of IL=O(100). +// +// 2. Parallelise the IC reduction across the warp (each thread handles a +// strided slice of IC) and finalise with __shfl_xor_sync. This gives +// 32× useful work per warp on top of the i-range narrowing. 
+// +// Layouts (matching the original kernel and the Vulkan / Metal patches): +// src0 (kernel) : [K, OC, IC] row-major → element (ki, oc, ic) at +// ic*(OC*K) + oc*K + ki +// src1 (input) : [IL, IC] row-major → element (i, ic) at ic*IL + i +// dst : [OL, OC] row-major → element (ol, oc) at oc*OL + ol +// +// Limitation (unchanged from the original kernel): only ne2==ne3==1 is +// supported; the host-side wrapper enforces that via the contiguous + +// shape assertions. +static __global__ void conv_transpose_1d_kernel( const int s0, const int p0, const int d0, const int output_size, const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, - const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, - const float * src0, const float * src1, float * dst) { - int global_index = threadIdx.x + blockIdx.x * blockDim.x; - if (global_index >= output_size) { + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const float * __restrict__ src0, const float * __restrict__ src1, float * __restrict__ dst) { + + const int ol = blockIdx.x; + const int oc = blockIdx.y; + if (ol >= dst_ne0 || oc >= dst_ne1) { return; } - int out_index = global_index / dst_ne0; - - float accumulator = 0; - - for (int c = 0; c < src0_ne2; c++) { - int idx = global_index % dst_ne0; - - int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0); - int input_offset = src1_ne0 * c; - - for (int i = 0; i < src1_ne0; i++) { - if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) { - continue; - } - int weight_idx = idx - i*s0; + const int K = src0_ne0; + const int OC = dst_ne1; + const int IC = src0_ne2; + const int IL = src1_ne0; + + // Range of input positions i that contribute to this output pixel. 
+ int i_start = (ol - K + 1 + s0 - 1) / s0; // ceil((ol - K + 1) / s0) + if (i_start < 0) i_start = 0; + int i_end = ol / s0; + if (i_end > IL - 1) i_end = IL - 1; + + const int tid = threadIdx.x; + const int nth = blockDim.x; + + float v = 0.0f; + + // Each thread handles a strided slice of IC; the range of i is + // already narrow (≤ K/s0 + 1), so the inner loop is the cheap one. + for (int ic = tid; ic < IC; ic += nth) { + const int kernel_base = (ic * OC + oc) * K; + const int input_base = ic * IL; + #pragma unroll 4 + for (int i = i_start; i <= i_end; ++i) { + const int ki = ol - i * s0; + v += src0[kernel_base + ki] * src1[input_base + i]; + } + } - float kernel_weight = src0[kernel_offset + weight_idx]; - float input_value = src1[input_offset+i]; + // Reduce across the warp. + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + v += __shfl_xor_sync(0xFFFFFFFFu, v, offset); + } - accumulator += kernel_weight * input_value; - } + if (tid == 0) { + dst[oc * dst_ne0 + ol] = v; } - dst[global_index] = accumulator; - GGML_UNUSED_VARS(p0, d0, src0_ne3, src1_ne3, dst_ne3, src1_ne1, dst_ne1, src1_ne2, dst_ne2); + + GGML_UNUSED_VARS(p0, d0, output_size, + src0_ne1, src0_ne3, src1_ne1, src1_ne2, src1_ne3, + dst_ne2, dst_ne3); } static void conv_transpose_1d_f32_f32_cuda( const int s0, const int p0, const int d0, const int output_size, const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, - const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, - const float * src0, const float * src1, float * dst, + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const float * src0, const float * src1, float * dst, cudaStream_t stream) { - const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE; - conv_transpose_1d_kernel<<>>( - s0,p0,d0,output_size, - 
src0_ne0, src0_ne1, src0_ne2, src0_ne3, - src1_ne0, src1_ne1, src1_ne2, src1_ne3, - dst_ne0, dst_ne1, dst_ne2, dst_ne3, - src0,src1, dst); + // Block = one warp (32 threads). Grid has one block per output pixel, + // i.e. (OL, OC). ne2/ne3 are required to be 1 by the existing host-side + // assertions, so we don't extend the grid into z. + const dim3 block_dim(CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 1, 1); + const dim3 grid_dim((unsigned)dst_ne0, (unsigned)dst_ne1, 1); + + conv_transpose_1d_kernel<<>>( + s0, p0, d0, output_size, + src0_ne0, src0_ne1, src0_ne2, src0_ne3, + src1_ne0, src1_ne1, src1_ne2, src1_ne3, + dst_ne0, dst_ne1, dst_ne2, dst_ne3, + src0, src1, dst); } void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/src/ggml-cuda/conv-transpose-1d.cuh b/src/ggml-cuda/conv-transpose-1d.cuh index 6c2cf666b6..618b090059 100644 --- a/src/ggml-cuda/conv-transpose-1d.cuh +++ b/src/ggml-cuda/conv-transpose-1d.cuh @@ -1,5 +1,6 @@ #include "common.cuh" -#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256 +// One warp per output pixel; see conv-transpose-1d.cu for why. +#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 32 void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/src/ggml-cuda/ggml-cuda.cu b/src/ggml-cuda/ggml-cuda.cu index 185956317e..4b86a5a457 100644 --- a/src/ggml-cuda/ggml-cuda.cu +++ b/src/ggml-cuda/ggml-cuda.cu @@ -81,10 +81,264 @@ #include #include #include +#include #include static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); +// ============================================================================= +// Per-op GPU timing logger (mirrors ggml-vulkan's GGML_VK_PERF_LOGGER=1). +// +// Enabled by setting GGML_CUDA_PERF_LOGGER=1. Prints aggregate per-op GPU +// time + dispatch count after every ggml_backend_cuda_graph_compute() call. 
+// Output layout intentionally mirrors ggml-vulkan's +// "----------------\nVulkan Timings:" block, but the header line reads +// "CUDA Timings:", so grep / awk one-liners must match that header instead. +// +// Implementation: +// - Uses cudaEventRecord pairs around each per-op dispatch; one timing +// event slot per dispatched op, ~3 µs of overhead per op. +// - cudaStreamSynchronize is forced at the end of each compute_graph call +// so elapsed times are readable before any subsequent re-record. +// - When the env var is set, CUDA Graphs are disabled (graphs would +// either re-record over still-pending events on subsequent launches, +// or hide per-op timing inside cudaGraphLaunch). +// +// Aggregation key encodes op + shape (plus src0 dtype for MUL_MAT*) so repeated +// calls at the same shape are grouped (matching Vulkan's vk_perf_logger behaviour). +// +// Thread safety: the logger is one process-wide instance with no internal +// locking, so only one thread at a time may run graph compute while it is on. +// ============================================================================= + +static const bool ggml_cuda_perf_logger_enabled = + (getenv("GGML_CUDA_PERF_LOGGER") != nullptr); + +class ggml_cuda_perf_logger { +public: + ggml_cuda_perf_logger() = default; + + ~ggml_cuda_perf_logger() { + // Flush any pending data — but DO NOT destroy events here. + // + // This is a Meyers-singleton; its destructor runs at static + // destruction time, AFTER main() returns and (often) after the + // CUDA runtime's own static destructors have run. Calling + // cudaEventDestroy on a torn-down driver can crash the process + // on exit. Letting the OS reclaim the events is safe — this + // logger is opt-in via env var, the leaked memory is process- + // lifetime regardless, and the event pool is bounded. If a + // long-running daemon ever wants to reset the logger mid-run, + // expose an explicit reset() that's called while CUDA is still + // alive. 
+ if (next_slot > 0 || !agg.empty()) { + print_and_clear(); + } + } + + // RAII helper: records the start event on construction; the consumer + // sets the resolved label via set_label() (after the dispatch picks + // a fusion / fallback branch); destruction records the end event. + class scope { + public: + scope(ggml_cuda_perf_logger * l, const ggml_tensor * n, cudaStream_t s) + : logger_(l), stream_(s), node_(n) { + slot_ = logger_ ? logger_->begin(stream_) : -1; + } + ~scope() { + if (logger_ && slot_ >= 0) { + logger_->end(stream_, slot_, node_, fusion_label_, n_fused_); + } + } + // Call this from inside the dispatch site to override the default + // per-op label with a fusion-specific one, e.g. + // "RMS_NORM+MUL+ADD" + n=3. + void set_label(const char * label, int n_fused = 1) { + fusion_label_ = label; + n_fused_ = n_fused; + } + private: + ggml_cuda_perf_logger * logger_ = nullptr; + cudaStream_t stream_ = nullptr; + const ggml_tensor * node_ = nullptr; + const char * fusion_label_ = nullptr; + int n_fused_ = 1; + int slot_ = -1; + }; + + // Force a synchronize+flush; called at the end of every + // ggml_backend_cuda_graph_compute when the env var is set. + void flush_and_print(cudaStream_t stream) { + if (next_slot == 0) return; + // Wait for all recorded events to fire (essential before + // cudaEventElapsedTime, and before we re-use the same slots). + (void)cudaStreamSynchronize(stream); + for (int i = 0; i < next_slot; ++i) { + // Skip slots where event creation or recording failed. 
+ if (!ev_starts[i] || !ev_ends[i]) continue; + float ms = 0.0f; + cudaError_t st = cudaEventElapsedTime(&ms, ev_starts[i], ev_ends[i]); + if (st != cudaSuccess) { + continue; + } + const uint64_t ns = (uint64_t)(ms * 1e6); + entry & e = agg[ev_names[i]]; + e.total_ns += ns; + e.count += 1; + } + next_slot = 0; + print_and_clear(); + } + +private: + struct entry { + uint64_t total_ns = 0; + uint64_t count = 0; + }; + + int begin(cudaStream_t stream) { + ensure_capacity(next_slot + 1); + // Defensive: if cudaEventCreate failed earlier (e.g. OOM), + // ev_starts[slot] is the zero-init default — recording on that + // would error at runtime. Skip the slot in that case; + // flush_and_print() already silently drops slots whose + // cudaEventElapsedTime call returns an error, so this composes + // cleanly. + if (ev_starts[next_slot] && ev_ends[next_slot]) { + cudaEventRecord(ev_starts[next_slot], stream); + } + return next_slot; + } + + void end(cudaStream_t stream, int slot, const ggml_tensor * node, + const char * fusion_label, int n_fused) { + if (ev_starts[slot] && ev_ends[slot]) { + cudaEventRecord(ev_ends[slot], stream); + } + ev_names[slot] = make_label(node, fusion_label, n_fused); + next_slot = slot + 1; + } + + void ensure_capacity(int needed) { + if ((int)ev_starts.size() >= needed) return; + const int target = std::max(needed, (int)ev_starts.size() * 2); + const int prev = (int)ev_starts.size(); + ev_starts.resize(target, nullptr); + ev_ends.resize(target, nullptr); + ev_names.resize(target); + for (int i = prev; i < target; ++i) { + // If create fails (e.g. OOM under load), leave the slot null + // and skip recording on it; flush_and_print() tolerates + // failed elapsed-time queries. Don't abort — this is opt-in + // diagnostic code, not on the hot path. 
+ if (cudaEventCreate(&ev_starts[i]) != cudaSuccess) { + ev_starts[i] = nullptr; + } + if (cudaEventCreate(&ev_ends[i]) != cudaSuccess) { + ev_ends[i] = nullptr; + } + } + } + + static std::string make_label(const ggml_tensor * node, + const char * fusion_label, + int n_fused) { + std::string s; + if (fusion_label) { + s = fusion_label; + s += " ("; + s += std::to_string(n_fused); + s += ") "; + } + if (!node) { + s += ""; + return s; + } + if (node->op == GGML_OP_UNARY) { + s += ggml_unary_op_name(ggml_get_unary_op(node)); + } else { + s += ggml_op_name(node->op); + } + // Append dtype + shape. Mirrors ggml-vulkan's vk_perf_logger + // encoding where possible (so cross-backend diffs of perf tables + // stay aligned). + if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { + const int64_t m = node->ne[0]; + const int64_t n = node->ne[1]; + const int64_t k = node->src[1] ? node->src[1]->ne[0] : 0; + const int64_t batch = node->ne[2] * node->ne[3]; + // ggml-vulkan adds "_VEC" suffix when n is small; do the same + // for grep parity. 16 is the threshold ggml-vulkan uses by + // default for the small-N matmul-vec path. + if (n <= 16) { + s += "_VEC"; + } + s += " "; + if (node->src[0]) s += ggml_type_name(node->src[0]->type); + s += " m=" + std::to_string(m); + s += " n=" + std::to_string(n); + s += " k=" + std::to_string(k); + if (batch > 1) s += " batch=" + std::to_string(batch); + } else { + s += " ("; + s += std::to_string(node->ne[0]); + for (int d = 1; d < GGML_MAX_DIMS; ++d) { + s += ","; + s += std::to_string(node->ne[d]); + } + s += ")"; + } + return s; + } + + // Print the current frame's aggregated timings then clear the + // accumulator. Each ggml_backend_cuda_graph_compute call therefore + // produces a self-contained "CUDA Timings:" block (matching how + // ggml-vulkan's vk_perf_logger::print_timings clears `timings` and + // `flops` after every print). 
Cross-call cumulation isn't useful: + // mixing prompt-phase (large n) and step-phase (n=1) op aggregates + // would produce confusing tables. + void print_and_clear() { + if (agg.empty()) return; + std::vector> rows(agg.begin(), agg.end()); + std::sort(rows.begin(), rows.end(), + [](auto & a, auto & b){ return a.second.total_ns > b.second.total_ns; }); + std::fprintf(stderr, "----------------\nCUDA Timings:\n"); + uint64_t total_all = 0; + for (const auto & [name, e] : rows) { + const double avg_us = e.count ? (double)e.total_ns / (double)e.count / 1000.0 : 0.0; + std::fprintf(stderr, "%s: %llu x %.3f us = %.3f us\n", + name.c_str(), + (unsigned long long)e.count, + avg_us, + (double)e.total_ns / 1000.0); + total_all += e.total_ns; + } + std::fprintf(stderr, "Total time: %.3f us.\n", (double)total_all / 1000.0); + agg.clear(); + } + + std::unordered_map agg; + + // Per-frame event pool. Resets next_slot=0 at every flush. + std::vector ev_starts; + std::vector ev_ends; + std::vector ev_names; + int next_slot = 0; +}; + +// Single process-wide instance. Constructed lazily on first reference +// (Meyers-singleton) so we never pay event-pool setup cost when the env +// var isn't set. Storage is a function-local static; lifetime is "until +// process exit". +static ggml_cuda_perf_logger & g_cuda_perf_logger_get() { + static ggml_cuda_perf_logger inst; + return inst; +} +// Convenience reference that the dispatch loop uses. Cheap because +// g_cuda_perf_logger_get() is a single-load function-local static. +#define g_cuda_perf_logger (g_cuda_perf_logger_get()) + [[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails @@ -3762,6 +4016,19 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } + // ----------------------------------------------------- + // Per-op timing scope (no-op unless GGML_CUDA_PERF_LOGGER=1). 
+ // Records cudaEventRecord(start) here; the matching end + // event is recorded by ~scope() at the bottom of the + // current dispatch — including before any `continue` taken + // by the fusion fast-paths below. Fusion-specific labels + // (e.g. "MUL_MAT+ADD") could be set inline by .set_label() + // next to each fused dispatch site if desired. + // ----------------------------------------------------- + ggml_cuda_perf_logger * perf_logger_ptr = + ggml_cuda_perf_logger_enabled ? &g_cuda_perf_logger : nullptr; + ggml_cuda_perf_logger::scope perf_scope(perf_logger_ptr, node, cuda_ctx->stream()); + // start of fusion operations static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); if (!disable_fusion) { @@ -4010,6 +4277,98 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } + // 3-op fusion: MUL_MAT + ADD(bias) + ADD(residual) + // + // Mirrors ggml-vulkan's MUL_MAT_ADD_ADD shader. The + // pattern is `((mat * y) + bias) + residual`, common in + // transformer attention-output and FFN-output blocks + // where the projection is followed by a bias add and a + // residual connection. Without the 3-op fusion, CUDA + // runs three separate kernels per such block; folding + // the residual ADD into the matmul-vec writeback saves + // the launch overhead of one stand-alone GGML_OP_ADD + // per residual. + // + // Placed above the 2-op {MUL_MAT, ADD} fusion below so + // the greedy match prefers the larger fusion when both + // apply (ggml_can_fuse already enforces that the + // intermediate node has no other consumers). + // + // Only MUL_MAT (not MUL_MAT_ID) is handled — the + // residual ADD pattern doesn't apply to MoE expert + // routing in any model the author has seen, and the + // host-side detection logic for ADD_ID would need a + // different path (residual would be a tensor index + // rather than a direct ADD source). 
+ fused_mul_mat_vec = false; + fused_node_count = 0; + if (ggml_can_fuse(cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) { + ggml_tensor * mm_node = cgraph->nodes[i]; + ggml_tensor * bias_node = cgraph->nodes[i + 1]; + ggml_tensor * residual_node = cgraph->nodes[i + 2]; + + // bias_node consumes (mm_node, bias_tensor) in some + // order; pick the one that's not mm_node. + ggml_tensor * bias_tensor = nullptr; + if (bias_node->src[0] == mm_node) { + bias_tensor = bias_node->src[1]; + } else if (bias_node->src[1] == mm_node) { + bias_tensor = bias_node->src[0]; + } + // residual_node must consume (bias_node, residual_tensor). + ggml_tensor * residual_tensor = nullptr; + if (bias_tensor) { + if (residual_node->src[0] == bias_node) { + residual_tensor = residual_node->src[1]; + } else if (residual_node->src[1] == bias_node) { + residual_tensor = residual_node->src[0]; + } + } + + // No broadcasting on either ADD — same constraint + // as the 2-op fusion below. This skips prompt- + // phase patterns where bias is [N] and mm output + // is [N, T]; those fall through to plain dispatch. + const bool no_bias_broadcast = + bias_tensor && + ggml_are_same_shape(bias_node->src[0], bias_node->src[1]); + const bool no_residual_broadcast = + residual_tensor && + ggml_are_same_shape(residual_node->src[0], residual_node->src[1]); + + if (bias_tensor && residual_tensor && no_bias_broadcast && no_residual_broadcast) { + const ggml_tensor * src0 = mm_node->src[0]; + const ggml_tensor * src1 = mm_node->src[1]; + const ggml_tensor * ids = mm_node->src[2]; + + // Use the 3-op fused kernel only if the 2-op + // fusion path would have been chosen anyway + // (matmul-vec regime, n=1 dst, non-Pascal + // arch). Larger-N matmul falls through. 
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.x_bias = bias_tensor; + fusion_data.x_residual = residual_tensor; + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, residual_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + } else if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.x_bias = bias_tensor; + fusion_data.x_residual = residual_tensor; + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, residual_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + } + } + } + + if (fused_mul_mat_vec) { + i += fused_node_count - 1; + continue; + } + fused_mul_mat_vec = false; fused_node_count = 0; @@ -4241,6 +4600,14 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key); + // Flush per-op timing data after each compute_graph call when the + // GGML_CUDA_PERF_LOGGER env var is set. Mirrors ggml-vulkan's + // post-compute flush. No-op when env var is unset (next_slot==0 + // and agg.empty() short-circuit inside flush_and_print). 
+ if (ggml_cuda_perf_logger_enabled) { + g_cuda_perf_logger.flush_and_print(cuda_ctx->stream()); + } + return GGML_STATUS_SUCCESS; } diff --git a/src/ggml-cuda/mmvf.cu b/src/ggml-cuda/mmvf.cu index d914720242..a7e32c1870 100644 --- a/src/ggml-cuda/mmvf.cu +++ b/src/ggml-cuda/mmvf.cu @@ -50,15 +50,18 @@ static __global__ void mul_mat_vec_f( bool use_gate = false; bool use_bias = false; bool use_gate_bias = false; + bool use_residual = false; ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU; const T * gate_x = nullptr; const float * x_bias = nullptr; const float * gate_bias = nullptr; + const float * x_residual = nullptr; if constexpr (has_fusion) { use_gate = fusion.gate != nullptr; use_bias = fusion.x_bias != nullptr; use_gate_bias = fusion.gate_bias != nullptr; + use_residual = fusion.x_residual != nullptr; glu_op = fusion.glu_op; if (use_gate) { @@ -73,6 +76,13 @@ static __global__ void mul_mat_vec_f( } else { use_gate_bias = false; } + if (use_residual) { + // Residual is added AFTER bias and (if any) GLU — matches the + // chatterbox graph shape `((mm * y) + bias) + residual` and + // ggml-vulkan's MUL_MAT_ADD_ADD shader. Same shape rules as + // x_bias (must match dst, no broadcasting). + x_residual = static_cast(fusion.x_residual); + } } if (use_gate) { @@ -88,6 +98,13 @@ static __global__ void mul_mat_vec_f( if (use_gate_bias) { gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; } + if (use_residual) { + // Residual must be offset into the same (sample, channel) slice + // as the bias-add output it gets summed with. Same shape rules + // as x_bias (no broadcasting; host-side fusion-detection in + // ggml_backend_cuda_graph_compute enforces this). + x_residual += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; + } } const float2 * y2 = (const float2 *) y; @@ -364,6 +381,10 @@ static __global__ void mul_mat_vec_f( break; } } + // Residual is added last — see x_residual comment near declaration. 
+ if (use_residual) { + value += x_residual[tid*stride_col_dst + row]; + } } dst[tid*stride_col_dst + row] = value; @@ -667,6 +688,14 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]); fusion_local.gate_bias = fusion->gate_bias->data; } + if (fusion->x_residual) { + // Same shape rules as x_bias — host-side fusion-detection + // logic enforces no broadcasting. + GGML_ASSERT(fusion->x_residual->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->x_residual->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->x_residual->ne[1] == src0->ne[2]); + fusion_local.x_residual = fusion->x_residual->data; + } fusion_local.glu_op = fusion->glu_op; } diff --git a/src/ggml-cuda/mmvq.cu b/src/ggml-cuda/mmvq.cu index 8f55cace1a..cc751a4442 100644 --- a/src/ggml-cuda/mmvq.cu +++ b/src/ggml-cuda/mmvq.cu @@ -429,24 +429,29 @@ static __global__ void mul_mat_vec_q( bool use_gate = false; bool use_bias = false; bool use_gate_bias = false; + bool use_residual = false; const void * vgate = nullptr; const float * x_bias = nullptr; const float * gate_bias = nullptr; + const float * x_residual = nullptr; ggml_glu_op active_glu; if constexpr (has_fusion) { - use_gate = fusion.gate != nullptr; - use_bias = fusion.x_bias != nullptr; - use_gate_bias = fusion.gate_bias != nullptr && use_gate; + use_gate = fusion.gate != nullptr; + use_bias = fusion.x_bias != nullptr; + use_gate_bias = fusion.gate_bias != nullptr && use_gate; + use_residual = fusion.x_residual != nullptr; vgate = fusion.gate; x_bias = (const float *) fusion.x_bias; gate_bias = (const float *) fusion.gate_bias; + x_residual = (const float *) fusion.x_residual; active_glu = fusion.glu_op; } - float x_biases[ncols_dst] = { 0.0f }; - float gate_biases[ncols_dst] = { 0.0f }; + float x_biases[ncols_dst] = { 0.0f }; + float gate_biases[ncols_dst] = { 0.0f }; + float x_residuals[ncols_dst] = { 0.0f }; if constexpr (has_fusion) { const uint32_t 
channel_bias = ids ? channel_x : channel_dst; if (use_bias) { @@ -461,6 +466,20 @@ static __global__ void mul_mat_vec_q( } } } + if (use_residual) { + // Residual is added AFTER bias and after the gate's GLU application + // (matches ggml-vulkan's MUL_MAT_ADD_ADD shader). Residual must be + // F32 with the same shape as dst — broadcasting is rejected in the + // host-side fusion-detection logic before we get here. + x_residual = x_residual + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; + if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 && + (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + x_residuals[j] = x_residual[j * stride_col_dst + threadIdx.x]; + } + } + } if (use_gate_bias) { gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 && @@ -580,13 +599,20 @@ static __global__ void mul_mat_vec_q( break; } } + // Residual is added last — matches the chatterbox graph + // shape `((mm * y) + bias) + residual` and Vulkan's + // MUL_MAT_ADD_ADD shader execution order. Skipped when + // x_residual was null (use_residual == false). 
+ if (use_residual) { + result += x_residuals[j]; + } } dst[j*stride_col_dst + threadIdx.x] = result; } } if constexpr (!has_fusion) { - GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate); + GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, use_residual, active_glu, gate_bias, x_bias, x_residual, tmp_gate); } } @@ -1074,6 +1100,16 @@ void ggml_cuda_mul_mat_vec_q( GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]); fusion_local.gate_bias = fusion->gate_bias->data; } + if (fusion->x_residual) { + // Residual must be F32 and exactly match the dst tensor in + // shape — no broadcasting (the host-side fusion-detection + // logic in ggml_cuda_graph_evaluate_and_capture rejects + // mismatched shapes). Mirrors the x_bias asserts above. + GGML_ASSERT(fusion->x_residual->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->x_residual->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->x_residual->ne[1] == src0->ne[2]); + fusion_local.x_residual = fusion->x_residual->data; + } fusion_local.glu_op = fusion->glu_op; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 828a9c14a4..7ebfeb2a17 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7679,6 +7679,18 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); + // HiFT-vocoder-realistic shapes (exercise the warp-cooperative + // CUDA kernel: IC > 32 so each lane sums several channels, kernel-vs-stride + // ratios > 1 hitting the inner-loop unroll path). These cases + // pass on both the legacy scalar kernel and the warp-cooperative + // variant; they're a regression guard so future kernel changes + // get caught at HiFT-realistic scale, not just on tiny shapes. 
+ test_cases.emplace_back(new test_conv_transpose_1d({303, 80, 1, 1}, {16, 64, 80, 1}, 8, 0, 1)); // HiFT layer 1 + test_cases.emplace_back(new test_conv_transpose_1d({150, 128, 1, 1}, {16, 64, 128, 1}, 8, 0, 1)); // HiFT layer 2 + test_cases.emplace_back(new test_conv_transpose_1d({100, 64, 1, 1}, { 8, 32, 64, 1}, 4, 0, 1)); // HiFT layer 3 (smaller s0) + test_cases.emplace_back(new test_conv_transpose_1d({ 50, 32, 1, 1}, { 4, 16, 32, 1}, 2, 0, 1)); // HiFT layer 4 (small) + test_cases.emplace_back(new test_conv_transpose_1d({200, 64, 1, 1}, {32, 32, 64, 1}, 8, 0, 1)); // K > s0 (multi-touch) + for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) { test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1, kernel_type)); test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type)); @@ -8636,6 +8648,23 @@ static std::vector> make_test_cases_eval() { use_id, 16, 8, b, with_bias, with_gate)); test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256, use_id, 16, 8, b, with_bias, with_gate, {1, 1})); + // Larger / decoder-realistic shapes — exercise the + // mul_mat_vec + add(bias) + GLU + add(residual) fusion + // path at scales typical of autoregressive transformer + // FFN blocks (m=1, k>=1024, n in [1024, 3072]). These + // pass on backends that fuse only mul_mat+add and run + // the residual ADD separately (current ggml-cuda) and + // also on backends that fuse all three (ggml-vulkan + // since MUL_MAT_ADD_ADD shaders, ggml-cuda after this + // PR). No batch dim variation here — that adds a + // 2-3x test-suite runtime for shapes already covered + // by the {4,2}/{1,1} sweep above. + if (!use_id) { + test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 1024, 1024, + false, 1, 1, false, with_bias, with_gate, {1, 1})); + test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 3072, 1024, + false, 1, 1, false, with_bias, with_gate, {1, 1})); + } } } }